Compare commits

..

49 Commits

Author SHA1 Message Date
pablonyx
53ab977ffa slack / team improvements 2025-02-19 18:23:44 -08:00
pablonyx
33f08ff489 k 2025-02-19 18:23:44 -08:00
pablonyx
2c0ef9dfcc add hubspot fix 2025-02-19 18:23:44 -08:00
pablonyx
d7152b540a simplify slack check 2025-02-19 18:23:44 -08:00
pablonyx
cac503bf11 add slack validation 2025-02-19 18:23:44 -08:00
pablonyx
c5fdd201e1 k 2025-02-19 18:23:44 -08:00
rkuo-danswer
8b4413694a fix usage of tenant_id (#4062)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-19 17:50:58 -08:00
pablonyx
57cf7d9fac default agent search on 2025-02-19 17:21:26 -08:00
Chris Weaver
ad4efb5f20 Pin xmlsec version + improve SAML flow (#4054)
* Pin xmlsec version

* testing

* test nginx conf change

* Pass through more

* Cleanup + remove DOMAIN across the board
2025-02-19 16:02:05 -08:00
evan-danswer
e304ec4ab6 Agent search history displayed answer (#4052) 2025-02-19 15:52:16 -08:00
joachim-danswer
1690dc45ba timout bumps (#4057) 2025-02-19 15:51:45 -08:00
pablonyx
7582ba1640 Fix streaming (#4055) 2025-02-19 15:23:40 -08:00
pablonyx
99fc546943 Miscellaneous indexing fixes (#4042) 2025-02-19 11:34:49 -08:00
pablonyx
353c185856 Update error class (#4006) 2025-02-19 10:52:23 -08:00
pablonyx
7c96b7f24e minor alembic nit 2025-02-19 10:47:33 -08:00
pablonyx
31524a3eff add connector validation (#4016) 2025-02-19 10:46:06 -08:00
rkuo-danswer
c9f618798e support scrolling before scraping (#4040)
* support scrolling before scraping

* fix mypy

* install playwright deps

---------

Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-02-19 17:54:58 +00:00
rkuo-danswer
11f6b44625 Feature/indexing hard timeout 3 (#3980)
* WIP

* implement hard timeout

* fix callbacks

* put back the timeout

* missed a file

* fixes

* try installing playwright deps

* Revert "try installing playwright deps"

This reverts commit 4217427568.

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-02-19 04:12:13 +00:00
pablonyx
e82a25f49e Non-SMTP password reset (#4031)
* update

* validate

* k

* minor cleanup

* nit

* finalize

* k

* fix tests

* fix tests

* fix tests
2025-02-19 02:02:28 +00:00
Weves
5a9ec61446 Don't pass thorugh parallel_tool_calls for o-family models 2025-02-18 18:57:05 -08:00
pablonyx
9635522de8 Admin default (#4032)
* clean up

* minor cleanup

* building

* update agnetic message look

* k

* fix alembic history
2025-02-18 18:31:54 -08:00
Yuhong Sun
630bdf71a3 Update README (#4044) 2025-02-18 18:31:28 -08:00
pablonyx
47fd4fa233 Strict Tenant ID Enforcement (#3871)
* strict tenant id enforcement

* k

* k

* nit

* merge

* nit

* k
2025-02-19 00:52:56 +00:00
Weves
2013beb9e0 Adjust behavior when display_model_names is null 2025-02-18 16:19:08 -08:00
pablonyx
466276161c Quick link fix (#4039) 2025-02-18 16:18:41 -08:00
rkuo-danswer
c934892c68 add index to document__tag.tag_id (#4038)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-02-18 19:51:36 +00:00
joachim-danswer
1daa3a663d timout bumps (#4037) 2025-02-18 18:26:29 +00:00
Chris Weaver
7324273233 Small confluence group sync tweaks (#4033) 2025-02-18 07:05:41 +00:00
evan-danswer
2b2ba5478c new is_agentic flag for chatmessages (#4026)
* new is_agentic flag for chatmessages

* added cancelled error to db

* added cancelled error to returned message
2025-02-18 04:20:33 +00:00
pablonyx
045a41d929 Add default slack bot disabling (#3935)
* add slack bot disabling

* update

* k

* minor
2025-02-18 04:08:33 +00:00
pablonyx
e3bc7cc747 improve validation schema (#3984) 2025-02-18 03:18:23 +00:00
evan-danswer
0826b035a2 Update README.md (#3908)
* Update README.md

help future integration test runners

* Update README.md

* Update README.md

---------

Co-authored-by: pablonyx <pablo@danswer.ai>
2025-02-18 03:08:47 +00:00
pablonyx
cf0e3d1ff4 fix main 2025-02-17 18:23:15 -08:00
evan-danswer
10c81f75e2 consistent refined answer improvement (#4027) 2025-02-17 21:02:03 +00:00
evan-danswer
5ca898bde2 Force use tool overrides (#4024)
* initial rename + timeout bump

* querry override
2025-02-17 21:01:24 +00:00
pablonyx
58b252727f UX (#4014) 2025-02-17 13:21:43 -08:00
joachim-danswer
86bd121806 no reranking if local model w/o GPU for Agent Search (#4011)
* no reranking if locql model w/o GPU

* more efficient gpu status calling

* fix unit tests

---------

Co-authored-by: Evan Lohn <evan@danswer.ai>
2025-02-17 14:13:24 +00:00
evan-danswer
9324f426c0 added timeouts for agent llm calls (#4019)
* added timeouts for agent llm calls

* timing suggestions in agent config

* improved timeout that actually exits early

* added new global timeout and connection timeout distinction

* fixed error raising bug and made entity extraction recoverable

* warnings and refactor

* mypy

---------

Co-authored-by: joachim-danswer <joachim@danswer.ai>
2025-02-17 07:02:19 +00:00
joachim-danswer
20d3efc86e By default, use primary LLM for initial & refined answer (#4012)
* By default, use primary LLM for initial & refined answer

Use of new env variable

* simplification
2025-02-16 23:20:07 +00:00
pablonyx
ec0e55fd39 Seeding count issue (#4009)
* k

* k

* quick nit

* nit
2025-02-16 20:49:25 +00:00
pablonyx
e441c899af Playwright + Chromatic update (#4015) 2025-02-16 13:03:45 -08:00
Chris Weaver
f1fc8ac19b Connector checkpointing (#3876)
* wip checkpointing/continue on failure

more stuff for checkpointing

Basic implementation

FE stuff

More checkpointing/failure handling

rebase

rebase

initial scaffolding for IT

IT to test checkpointing

Cleanup

cleanup

Fix it

Rebase

Add todo

Fix actions IT

Test more

Pagination + fixes + cleanup

Fix IT networking

fix it

* rebase

* Address misc comments

* Address comments

* Remove unused router

* rebase

* Fix mypy

* Fixes

* fix it

* Fix tests

* Add drop index

* Add retries

* reset lock timeout

* Try hard drop of schema

* Add timeout/retries to downgrade

* rebase

* test

* test

* test

* Close all connections

* test closing idle only

* Fix it

* fix

* try using null pool

* Test

* fix

* rebase

* log

* Fix

* apply null pool

* Fix other test

* Fix quality checks

* Test not using the fixture

* Fix ordering

* fix test

* Change pooling behavior
2025-02-16 02:34:39 +00:00
Weves
bc087fc20e Fix ruff 2025-02-15 16:35:15 -08:00
Yuhong Sun
ab8081c36b k 2025-02-15 13:42:43 -08:00
Adam Siemiginowski
f371efc916 Fix Zulip connector schema + links and enable temporal metadata (#4005) 2025-02-15 11:49:41 -08:00
pablonyx
7fd5d31dbe Minor background process log cleanup (#4010) 2025-02-15 11:03:10 -08:00
rkuo-danswer
2829e6715e Feature/propagate exceptions (#3974)
* better propagation of exceptions up the stack

* remove debug testing

* refactor the watchdog more to emit data consistently at the end of the function

* enumerate a lot more terminal statuses

* handle more codes

* improve logging

* handle "-9"

* single line exception logging

* typo/grammar

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-02-15 04:53:01 +00:00
Weves
bc7b4ec396 Fix typing for metadata 2025-02-14 18:19:37 -08:00
pablonyx
697f8bc1c6 Reduce background errors (#4004) 2025-02-14 17:35:26 -08:00
272 changed files with 7719 additions and 2528 deletions

View File

@@ -99,7 +99,7 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
@@ -108,12 +108,13 @@ jobs:
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
@@ -143,24 +144,27 @@ jobs:
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f danswer-stack-api_server-1 &
docker logs -f onyx-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -190,15 +194,24 @@ jobs:
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network danswer-stack_default \
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
@@ -208,6 +221,8 @@ jobs:
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
@@ -229,13 +244,13 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
@@ -249,4 +264,4 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
docker compose -f docker-compose.dev.yml -p onyx-stack down -v

View File

@@ -1,6 +1,6 @@
name: Run Chromatic Tests
name: Run Playwright Tests
concurrency:
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
group: Run-Playwright-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
@@ -198,43 +198,47 @@ jobs:
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
chromatic-tests:
name: Chromatic Tests
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.
needs: playwright-tests
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
# chromatic-tests:
# name: Chromatic Tests
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
# needs: playwright-tests
# runs-on:
# [
# runs-on,
# runner=32cpu-linux-x64,
# disk=large,
# "run-id=${{ github.run_id }}",
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
- name: Install node dependencies
working-directory: ./web
run: npm ci
# - name: Setup node
# uses: actions/setup-node@v4
# with:
# node-version: 22
- name: Download Playwright test results
uses: actions/download-artifact@v4
with:
name: test-results
path: ./web/test-results
# - name: Install node dependencies
# working-directory: ./web
# run: npm ci
- name: Run Chromatic
uses: chromaui/action@latest
with:
playwright: true
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: ./web
env:
CHROMATIC_ARCHIVE_LOCATION: ./test-results
# - name: Download Playwright test results
# uses: actions/download-artifact@v4
# with:
# name: test-results
# path: ./web/test-results
# - name: Run Chromatic
# uses: chromaui/action@latest
# with:
# playwright: true
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
# workingDir: ./web
# env:
# CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -74,7 +74,9 @@ jobs:
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors

View File

@@ -205,7 +205,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
],
"presentation": {
"group": "2",

123
README.md
View File

@@ -24,113 +24,94 @@
</a>
</p>
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
configuring AI Assistants.
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
There are over 40 supported connectors such as Google Drive, Slack, Confluence, Salesforce, etc. which keep knowledge and permissions up to date.
Create custom AI agents with unique prompts, knowledge, and actions the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
supported?" or "Where's the pull request for feature Y?"
<h3>Usage</h3>
<h3>Feature Showcase</h3>
Onyx Web App:
**Deep research over your team's knowledge:**
https://github.com/onyx-dot-app/onyx/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
https://github.com/onyx-dot-app/onyx/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
**Use Onyx as a secure AI Chat with any LLM:**
![Onyx Chat Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxChatSilentDemo.gif)
**Easily set up connectors to your apps:**
![Onyx Connector Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxConnectorSilentDemo.gif)
**Access Onyx where your team already works:**
![Onyx Bot Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxBot.png)
For more details on the Admin UI to manage connectors and users, check out our
<strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!
## Deployment
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
We also have built-in support for high-availability/scalable deployment on Kubernetes.
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
## 💃 Main Features
- Chat UI with the ability to select documents to chat with.
- Create custom AI Assistants with different prompts and backing knowledge sets.
- Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
- Document Search + AI Answers for natural language queries.
- Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
- Slack integration to get answers and search results directly in Slack.
## 🚧 Roadmap
- Chat/Prompt sharing with specific teammates and user groups.
- Multimodal model support, chat with images, video etc.
- Choosing between LLMs and parameters during chat session.
- Tool calling and agent configurations options.
- Extensions to the Chrome Plugin
- Latest methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language
## Other Notable Benefits of Onyx
- User Authentication with document level access management.
- Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
- Admin Dashboard to configure connectors, document-sets, access, etc.
- Custom deep learning models + learn from user feedback.
- Easy deployment and ability to host Onyx anywhere of your choosing.
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models only through Onyx + learn from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
## 🔌 Connectors
Keep knowledge and access up to sync across 40+ connectors:
Efficiently pulls the latest changes from:
- Slack
- GitHub
- Google Drive
- Confluence
- Slack
- Gmail
- Salesforce
- Microsoft Sharepoint
- Github
- Jira
- Zendesk
- Gmail
- Notion
- Gong
- Slab
- Linear
- Productboard
- Guru
- Bookstack
- Document360
- Sharepoint
- Hubspot
- Microsoft Teams
- Dropbox
- Local Files
- Websites
- And more ...
## 📚 Editions
See the full list [here](https://docs.onyx.app/connectors).
## 📚 Licensing
There are two editions of Onyx:
- Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
- Single Sign-On (SSO), with support for both SAML and OIDC
- Role-based access control
- Document permission inheritance from connected sources
- Usage analytics and query history accessible to admins
- Whitelabeling
- API key authentication
- Encryption of secrets
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
For feature details, check out [our website](https://www.onyx.app/pricing).
To try the Onyx Enterprise Edition:
1. Checkout [Onyx Cloud](https://cloud.onyx.app/signup).
2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)

View File

@@ -28,11 +28,11 @@ RUN apt-get update && \
curl \
zip \
ca-certificates \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 \
libgnutls30 \
libblkid1 \
libmount1 \
libsmartcols1 \
libuuid1 \
libxmlsec1-dev \
pkg-config \
gcc \

View File

@@ -0,0 +1,27 @@
"""Add indexes to document__tag
Revision ID: 1a03d2c2856b
Revises: 9c00a2bccb83
Create Date: 2025-02-18 10:45:13.957807
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1a03d2c2856b"
down_revision = "9c00a2bccb83"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
op.f("ix_document__tag_tag_id"),
"document__tag",
["tag_id"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_document__tag_tag_id"), table_name="document__tag")

View File

@@ -0,0 +1,43 @@
"""chat_message_agentic
Revision ID: 9c00a2bccb83
Revises: b7a7eee5aa15
Create Date: 2025-02-17 11:15:43.081150
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9c00a2bccb83"
down_revision = "b7a7eee5aa15"
branch_labels = None
depends_on = None
def upgrade() -> None:
# First add the column as nullable
op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))
# Update existing rows based on presence of SubQuestions
op.execute(
"""
UPDATE chat_message
SET is_agentic = EXISTS (
SELECT 1
FROM agent__sub_question
WHERE agent__sub_question.primary_question_id = chat_message.id
)
WHERE is_agentic IS NULL
"""
)
# Make the column non-nullable with a default value of False
op.alter_column(
"chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
)
def downgrade() -> None:
op.drop_column("chat_message", "is_agentic")

View File

@@ -0,0 +1,29 @@
"""remove inactive ccpair status on downgrade
Revision ID: acaab4ef4507
Revises: b388730a2899
Create Date: 2025-02-16 18:21:41.330212
"""
from alembic import op
from onyx.db.models import ConnectorCredentialPair
from onyx.db.enums import ConnectorCredentialPairStatus
from sqlalchemy import update
# revision identifiers, used by Alembic.
revision = "acaab4ef4507"
down_revision = "b388730a2899"
branch_labels = None
depends_on = None
def upgrade() -> None:
pass
def downgrade() -> None:
op.execute(
update(ConnectorCredentialPair)
.where(ConnectorCredentialPair.status == ConnectorCredentialPairStatus.INVALID)
.values(status=ConnectorCredentialPairStatus.ACTIVE)
)

View File

@@ -0,0 +1,31 @@
"""nullable preferences
Revision ID: b388730a2899
Revises: 1a03d2c2856b
Create Date: 2025-02-17 18:49:22.643902
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b388730a2899"
down_revision = "1a03d2c2856b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("user", "temperature_override_enabled", nullable=True)
op.alter_column("user", "auto_scroll", nullable=True)
def downgrade() -> None:
# Ensure no null values before making columns non-nullable
op.execute(
'UPDATE "user" SET temperature_override_enabled = false WHERE temperature_override_enabled IS NULL'
)
op.execute('UPDATE "user" SET auto_scroll = false WHERE auto_scroll IS NULL')
op.alter_column("user", "temperature_override_enabled", nullable=False)
op.alter_column("user", "auto_scroll", nullable=False)

View File

@@ -0,0 +1,124 @@
"""Add checkpointing/failure handling
Revision ID: b7a7eee5aa15
Revises: f39c5794c10a
Create Date: 2025-01-24 15:17:36.763172
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b7a7eee5aa15"
down_revision = "f39c5794c10a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"index_attempt",
sa.Column("checkpoint_pointer", sa.String(), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True),
)
op.create_index(
"ix_index_attempt_cc_pair_settings_poll",
"index_attempt",
[
"connector_credential_pair_id",
"search_settings_id",
"status",
sa.text("time_updated DESC"),
],
)
# Drop the old IndexAttemptError table
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
op.drop_table("index_attempt_errors")
# Create the new version of the table
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("index_attempt_id", sa.Integer(), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=True),
sa.Column("document_link", sa.String(), nullable=True),
sa.Column("entity_id", sa.String(), nullable=True),
sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True),
sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True),
sa.Column("failure_message", sa.Text(), nullable=False),
sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
),
)
def downgrade() -> None:
op.execute("SET lock_timeout = '5s'")
# try a few times to drop the table, this has been observed to fail due to other locks
# blocking the drop
NUM_TRIES = 10
for i in range(NUM_TRIES):
try:
op.drop_table("index_attempt_errors")
break
except Exception as e:
if i == NUM_TRIES - 1:
raise e
print(f"Error dropping table: {e}. Retrying...")
op.execute("SET lock_timeout = DEFAULT")
# Recreate the old IndexAttemptError table
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
sa.Column("batch", sa.Integer(), nullable=True),
sa.Column("doc_summaries", postgresql.JSONB(), nullable=False),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("traceback", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
)
op.create_index(
"index_attempt_id",
"index_attempt_errors",
["time_created"],
)
op.drop_index("ix_index_attempt_cc_pair_settings_poll")
op.drop_column("index_attempt", "checkpoint_pointer")
op.drop_column("index_attempt", "poll_range_start")
op.drop_column("index_attempt", "poll_range_end")

View File

@@ -21,7 +21,7 @@ logger = setup_logger()
def perform_ttl_management_task(
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -44,7 +44,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
@@ -62,7 +62,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,

View File

@@ -14,30 +14,24 @@ def _build_group_member_email_map(
confluence_client: OnyxConfluence, cc_pair_id: int
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user_result in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user_result}")
for user in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user}")
user = user_result.get("user", {})
if not user:
msg = f"user result missing user field: {user_result}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
continue
email = user.get("email")
email = user.email
if not email:
# This field is only present in Confluence Server
user_name = user.get("username")
user_name = user.username
# If it is present, try to get the email using a Server-specific method
if user_name:
email = get_user_email_from_username__server(
confluence_client=confluence_client,
user_name=user_name,
)
if not email:
# If we still don't have an email, skip this user
msg = f"user result missing email field: {user_result}"
if user.get("type") == "app":
msg = f"user result missing email field: {user}"
if user.type == "app":
logger.warning(msg)
else:
emit_background_error(msg, cc_pair_id=cc_pair_id)
@@ -45,7 +39,7 @@ def _build_group_member_email_map(
continue
all_users_groups: set[str] = set()
for group in confluence_client.paginated_groups_by_user_retrieval(user):
for group in confluence_client.paginated_groups_by_user_retrieval(user.user_id):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)

View File

@@ -5,7 +5,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)

View File

@@ -33,7 +33,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
return await call_next(request)
except Exception as e:
logger.error(f"Error in tenant ID middleware: {str(e)}")
logger.exception(f"Error in tenant ID middleware: {str(e)}")
raise
@@ -49,7 +49,7 @@ async def _get_tenant_id_from_request(
"""
# Check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id:
if tenant_id is not None:
return tenant_id
# Check for anonymous user cookie

View File

@@ -36,12 +36,12 @@ from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -271,12 +271,12 @@ def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
tenant_id = get_current_tenant_id()
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
@@ -329,7 +329,6 @@ def handle_slack_oauth_callback(
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
@@ -337,7 +336,7 @@ def handle_slack_oauth_callback(
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (
@@ -523,7 +522,6 @@ def handle_google_drive_oauth_callback(
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
@@ -531,7 +529,7 @@ def handle_google_drive_oauth_callback(
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (

View File

@@ -28,7 +28,7 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global(tenant_id)
@@ -52,8 +52,8 @@ User rate limits
"""
def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
@@ -94,7 +94,7 @@ User Group rate limits
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
if group_rate_limits:

View File

@@ -41,14 +41,15 @@ from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
@@ -57,13 +58,14 @@ router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
tenant_id: str | None = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@@ -72,15 +74,15 @@ async def get_anonymous_user_path_api(
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
tenant_id: str = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
@@ -101,7 +103,7 @@ async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_tenant(tenant_id=None) as db_session:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
@@ -150,14 +152,17 @@ async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
# Fetch tenant_id and current tenant's information
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
@@ -181,6 +186,8 @@ async def create_subscription_session(
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
@@ -197,7 +204,7 @@ async def impersonate_user(
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id) as tenant_session:
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
@@ -221,8 +228,9 @@ async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"

View File

@@ -118,7 +118,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
@@ -134,7 +134,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,

View File

@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()

View File

@@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
)
graph.add_node(
node="llm_tool_choice",
action=llm_tool_choice,
node="choose_tool",
action=choose_tool,
)
graph.add_node(
node="tool_call",
action=tool_call,
node="call_tool",
action=call_tool,
)
graph.add_node(
@@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
graph.add_edge(
start_key="tool_call",
start_key="call_tool",
end_key="basic_use_tool_response",
)
@@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
# If there are no tool calls, basic graph already streamed the answer
END
if state.tool_choice is None
else "tool_call"
else "call_tool"
)

View File

@@ -31,12 +31,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -85,9 +87,11 @@ def check_sub_answer(
agent_error: AgentErrorLog | None = None
response: BaseMessage | None = None
try:
response = fast_llm.invoke(
response = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
)
quality_str: str = cast(str, response.content)
@@ -96,7 +100,7 @@ def check_sub_answer(
)
log_result = f"Answer quality: {quality_str}"
except LLMTimeoutError:
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import merge_message_runs
@@ -47,11 +46,13 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -110,15 +111,14 @@ def generate_sub_answer(
config=fast_llm.config,
)
response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
response: list[str] = []
try:
def stream_sub_answer() -> list[str]:
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -142,8 +142,15 @@ def generate_sub_answer(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
except LLMTimeoutError:
try:
response = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
stream_sub_answer,
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
@@ -60,11 +59,15 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -77,6 +80,7 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
_llm_node_error_strings = LLMNodeErrorStrings(
@@ -230,7 +234,11 @@ def generate_initial_answer(
sub_questions = all_sub_questions # Replace the original assignment
model = graph_config.tooling.fast_llm
model = (
graph_config.tooling.fast_llm
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
else graph_config.tooling.primary_llm
)
doc_context = format_docs(answer_generation_documents.context_documents)
doc_context = trim_prompt_piece(
@@ -260,15 +268,16 @@ def generate_initial_answer(
)
]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
streamed_tokens: list[str] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
try:
def stream_initial_answer() -> list[str]:
response: list[str] = []
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -292,9 +301,16 @@ def generate_initial_answer(
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
response.append(content)
return response
except LLMTimeoutError:
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
stream_initial_answer,
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -36,7 +36,10 @@ from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -47,6 +50,7 @@ from onyx.prompts.agent_search import (
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -131,10 +135,12 @@ def decompose_orig_question(
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = dispatch_separated(
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
dispatch_separated,
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
@@ -154,7 +160,7 @@ def decompose_orig_question(
)
write_custom_event("stream_finished", stop_event, writer)
except LLMTimeoutError as e:
except (LLMTimeoutError, TimeoutError) as e:
logger.error("LLM Timeout Error - decompose orig question")
raise e # fail loudly on this critical step
except LLMRateLimitError as e:

View File

@@ -25,7 +25,7 @@ logger = setup_logger()
def route_initial_tool_choice(
state: MainState, config: RunnableConfig
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
) -> Literal["call_tool", "start_agent_search", "logging_node"]:
"""
LangGraph edge to route to agent search.
"""
@@ -38,7 +38,7 @@ def route_initial_tool_choice(
):
return "start_agent_search"
else:
return "tool_call"
return "call_tool"
else:
return "logging_node"

View File

@@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
@@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
# Choose the initial tool
graph.add_node(
node="initial_tool_choice",
action=llm_tool_choice,
action=choose_tool,
)
# Call the tool, if required
graph.add_node(
node="tool_call",
action=tool_call,
node="call_tool",
action=call_tool,
)
# Use the tool response
@@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph.add_conditional_edges(
"initial_tool_choice",
route_initial_tool_choice,
["tool_call", "start_agent_search", "logging_node"],
["call_tool", "start_agent_search", "logging_node"],
)
graph.add_edge(
start_key="tool_call",
start_key="call_tool",
end_key="basic_use_tool_response",
)
graph.add_edge(

View File

@@ -33,13 +33,15 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -105,11 +107,14 @@ def compare_answers(
refined_answer_improvement: bool | None = None
# no need to stream this
try:
resp = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
resp = run_with_timeout(
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
)
except LLMTimeoutError:
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -44,7 +44,10 @@ from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -53,6 +56,7 @@ from onyx.prompts.agent_search import (
)
from onyx.tools.models import ToolCallKickoff
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -134,15 +138,17 @@ def create_refined_sub_questions(
agent_error: AgentErrorLog | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = dispatch_separated(
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
dispatch_separated,
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
except LLMTimeoutError:
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -22,11 +22,17 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
@@ -84,30 +90,42 @@ def extract_entities_terms(
]
fast_llm = graph_config.tooling.fast_llm
# Grader
llm_response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
entity_extraction_result = EntityExtractionResult.model_validate_json(
cleaned_response
llm_response = run_with_timeout(
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
except ValueError:
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
entity_extraction_result = EntityExtractionResult.model_validate_json(
cleaned_response
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - extract entities terms")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
except LLMRateLimitError:
logger.error("LLM Rate Limit Error - extract entities terms")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
return EntityTermExtractionUpdate(

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
@@ -66,14 +65,21 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION,
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -92,6 +98,7 @@ from onyx.prompts.agent_search import (
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -253,7 +260,12 @@ def generate_validate_refined_answer(
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
model = graph_config.tooling.fast_llm
model = (
graph_config.tooling.fast_llm
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
else graph_config.tooling.primary_llm
)
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
relevant_docs_str = trim_prompt_piece(
model.config,
@@ -284,13 +296,13 @@ def generate_validate_refined_answer(
)
]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
streamed_tokens: list[str] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
try:
def stream_refined_answer() -> list[str]:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -315,8 +327,15 @@ def generate_validate_refined_answer(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
return streamed_tokens
except LLMTimeoutError:
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
stream_refined_answer,
)
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
@@ -383,16 +402,20 @@ def generate_validate_refined_answer(
)
]
validation_model = graph_config.tooling.fast_llm
try:
validation_response = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
validation_response = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),
positive_value=AGENT_POSITIVE_VALUE_STR,
separator=AGENT_ANSWER_SEPARATOR,
)
except LLMTimeoutError:
except (LLMTimeoutError, TimeoutError):
refined_answer_quality = True
logger.error("LLM Timeout Error - validate refined answer")

View File

@@ -34,14 +34,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -69,7 +71,7 @@ def expand_queries(
node_start_time = datetime.now()
question = state.question
llm = graph_config.tooling.fast_llm
model = graph_config.tooling.fast_llm
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_num = 0, 0
@@ -88,10 +90,12 @@ def expand_queries(
rewritten_queries = []
try:
llm_response_list = dispatch_separated(
llm.stream(
llm_response_list = run_with_timeout(
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
dispatch_separated,
model.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)
@@ -101,7 +105,7 @@ def expand_queries(
rewritten_queries = llm_response.split("\n")
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
except LLMTimeoutError:
except (LLMTimeoutError, TimeoutError):
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -55,6 +55,7 @@ def rerank_documents(
# Note that these are passed in values from the API and are overrides which are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
if rerank_settings is None:
with get_session_context_manager() as db_session:
@@ -62,23 +63,31 @@ def rerank_documents(
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)
# Initial default: no reranking. Will be overwritten below if reranking is warranted
reranked_documents = verified_documents
if should_rerank(rerank_settings) and len(verified_documents) > 0:
if len(verified_documents) > 1:
reranked_documents = rerank_sections(
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
if not allow_agent_reranking:
logger.info("Use of local rerank model without GPU, skipping reranking")
# No reranking, stay with verified_documents as default
else:
# Reranking is warranted, use the rerank_sections functon
reranked_documents = rerank_sections(
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
else:
logger.warning(
f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
reranked_documents = verified_documents
# No reranking, stay with verified_documents as default
else:
logger.warning("No reranking settings found, using unranked documents")
reranked_documents = verified_documents
# No reranking, stay with verified_documents as default
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:

View File

@@ -25,13 +25,15 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
DOCUMENT_VERIFICATION_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -86,8 +88,11 @@ def verify_documents(
] # default is to treat document as relevant
try:
response = fast_llm.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
response = run_with_timeout(
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
)
assert isinstance(response.content, str)
@@ -96,7 +101,7 @@ def verify_documents(
):
verified_documents = []
except LLMTimeoutError:
except (LLMTimeoutError, TimeoutError):
# In this case, we decide to continue and don't raise an error, as
# little harm in letting some docs through that are less relevant.
logger.error("LLM Timeout Error - verify documents")

View File

@@ -67,6 +67,7 @@ class GraphSearchConfig(BaseModel):
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
allow_agent_reranking: bool = False
class GraphConfig(BaseModel):

View File

@@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
write_custom_event("basic_response", packet, writer)
def tool_call(
def call_tool(
state: ToolChoiceUpdate,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,

View File

@@ -25,7 +25,7 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def llm_tool_choice(
def choose_tool(
state: ToolChoiceState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,

View File

@@ -43,8 +43,9 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
@@ -80,6 +81,7 @@ from onyx.tools.tool_implementations.search.search_tool import SearchResponseSum
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
@@ -395,11 +397,13 @@ def summarize_history(
)
try:
history_response = llm.invoke(
history_response = run_with_timeout(
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION,
llm.invoke,
history_context_prompt,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
)
except LLMTimeoutError:
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this

View File

@@ -42,4 +42,5 @@ def fetch_no_auth_user(
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
is_anonymous_user=anonymous_user_enabled,
password_configured=False,
)

View File

@@ -1,5 +1,7 @@
import json
import random
import secrets
import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
@@ -86,7 +88,6 @@ from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
@@ -104,6 +105,7 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -139,6 +141,30 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
return email or ""
def generate_password() -> str:
lowercase_letters = string.ascii_lowercase
uppercase_letters = string.ascii_uppercase
digits = string.digits
special_characters = string.punctuation
# Ensure at least one of each required character type
password = [
secrets.choice(uppercase_letters),
secrets.choice(digits),
secrets.choice(special_characters),
]
# Fill the rest with a mix of characters
remaining_length = 12 - len(password)
all_characters = lowercase_letters + uppercase_letters + digits + special_characters
password.extend(secrets.choice(all_characters) for _ in range(remaining_length))
# Shuffle the password to randomize the position of the required characters
random.shuffle(password)
return "".join(password)
def user_needs_to_be_verified() -> bool:
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
return REQUIRE_EMAIL_VERIFICATION
@@ -189,7 +215,7 @@ def verify_email_is_invited(email: str) -> None:
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -591,6 +617,39 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return user
async def reset_password_as_admin(self, user_id: uuid.UUID) -> str:
"""Admin-only. Generate a random password for a user and return it."""
user = await self.get(user_id)
new_password = generate_password()
await self._update(user, {"password": new_password})
return new_password
async def change_password_if_old_matches(
self, user: User, old_password: str, new_password: str
) -> None:
"""
For normal users to change password if they know the old one.
Raises 400 if old password doesn't match.
"""
verified, updated_password_hash = self.password_helper.verify_and_update(
old_password, user.hashed_password
)
if not verified:
# Raise some HTTPException (or your custom exception) if old password is invalid:
from fastapi import HTTPException, status
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid current password",
)
# If the hash was upgraded behind the scenes, we can keep it before setting the new password:
if updated_password_hash:
user.hashed_password = updated_password_hash
# Now apply and validate the new password
await self._update(user, {"password": new_password})
async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
@@ -815,8 +874,9 @@ async def current_limited_user(
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> User | None:
tenant_id = get_current_tenant_id()
return await double_check_user(
user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
)

View File

@@ -33,6 +33,7 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
@@ -58,13 +59,35 @@ else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
class TenantAwareTask(Task):
"""A custom base Task that sets tenant_id in a contextvar before running."""
abstract = True # So Celery knows not to register this as a real task.
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Grab tenant_id from the kwargs, or fallback to default if missing.
tenant_id = kwargs.get("tenant_id", None) or POSTGRES_DEFAULT_SCHEMA
# Set the context var
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Actually run the task now
try:
return super().__call__(*args, **kwargs)
finally:
# Clear or reset after the task runs
# so it does not leak into any subsequent tasks on the same worker process
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
@task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**kwds: Any,
**other_kwargs: Any,
) -> None:
pass
@@ -201,7 +224,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_redis_client(tenant_id=None)
r = get_shared_redis_client()
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@@ -287,7 +310,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_redis_client(tenant_id=None)
r = get_shared_redis_client()
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")
@@ -439,24 +462,6 @@ class TenantContextFilter(logging.Filter):
return True
@task_prerun.connect
def set_tenant_id(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
"""Signal handler to set tenant ID in context var before task starts."""
tenant_id = (
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if kwargs
else POSTGRES_DEFAULT_SCHEMA
)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@task_postrun.connect
def reset_tenant_id(
sender: Any | None = None,

View File

@@ -132,6 +132,7 @@ class DynamicTenantScheduler(PersistentScheduler):
f"Adding options to task {tenant_task_name}: {options}"
)
tenant_task["options"] = options
new_schedule[tenant_task_name] = tenant_task
return new_schedule
@@ -256,3 +257,4 @@ def on_setup_logging(
celery_app.conf.beat_scheduler = DynamicTenantScheduler
celery_app.conf.task_default_base = app_base.TenantAwareTask

View File

@@ -20,6 +20,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.heavy")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -21,6 +21,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.indexing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -23,6 +23,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.light")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -20,6 +20,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -24,7 +24,7 @@ from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from onyx.db.engine import get_session_with_default_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import SqlEngine
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
@@ -38,7 +38,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -47,6 +47,7 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.primary")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
@@ -101,7 +102,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_redis_client(tenant_id=None)
r = get_shared_redis_client()
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
@@ -158,7 +159,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
with get_session_with_default_tenant() as db_session:
with get_session_with_current_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
@@ -234,7 +235,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
lock: RedisLock = worker.primary_worker_lock
r = get_redis_client(tenant_id=None)
r = get_shared_redis_client()
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")

View File

@@ -36,6 +36,15 @@ beat_task_templates.extend(
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-checkpoint-cleanup",
"task": OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,

View File

@@ -27,7 +27,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -62,8 +62,8 @@ class TaskDependencyError(RuntimeError):
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r = get_redis_client()
r_replica = get_redis_replica_client()
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
@@ -77,14 +77,14 @@ def check_for_connector_deletion_task(
try:
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
cc_pair_ids.append(cc_pair.id)
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
@@ -277,7 +277,7 @@ def monitor_connector_deletion_taskset(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
@@ -287,7 +287,7 @@ def monitor_connector_deletion_taskset(
)
return
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,

View File

@@ -45,7 +45,7 @@ from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -119,13 +119,13 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None:
# TODO(rkuo): merge into check function after lookup table for fences is added
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -140,7 +140,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
try:
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
@@ -189,7 +189,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
@@ -247,7 +247,7 @@ def try_creating_permissions_sync_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
@@ -321,7 +321,7 @@ def connector_permission_sync_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -378,7 +378,7 @@ def connector_permission_sync_generator_task(
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -480,7 +480,8 @@ def update_external_document_permissions_task(
external_access = document_external_access.external_access
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# Add the users to the DB if they don't exist
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),

View File

@@ -39,7 +39,7 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -122,8 +122,8 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -140,7 +140,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
try:
cc_pair_ids_to_sync: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
# We only want to sync one cc_pair per source type in
@@ -230,7 +230,7 @@ def try_creating_external_group_sync_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
@@ -296,7 +296,7 @@ def connector_external_group_sync_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -357,7 +357,7 @@ def connector_external_group_sync_generator_task(
payload.started = datetime.now(timezone.utc)
redis_connector.external_group_sync.set_fence(payload)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -384,6 +384,7 @@ def connector_external_group_sync_generator_task(
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
)
logger.debug(f"New external user groups: {external_user_groups}")
replace_user__ext_group_for_cc_pair(
db_session=db_session,
@@ -408,7 +409,7 @@ def connector_external_group_sync_generator_task(
task_logger.exception(msg)
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
@@ -459,7 +460,6 @@ def validate_external_group_sync_fences(
)
lock_beat.reacquire()
return

View File

@@ -1,9 +1,10 @@
import multiprocessing
import os
import sys
import time
import traceback
from datetime import datetime
from datetime import timezone
from enum import Enum
from http import HTTPStatus
from time import sleep
from typing import Any
@@ -15,6 +16,7 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -26,22 +28,31 @@ from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attem
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
from onyx.background.indexing.checkpointing_utils import (
get_index_attempts_with_old_checkpoints,
)
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import IndexingMode
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
@@ -70,6 +81,134 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
class IndexingWatchdogTerminalStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SPAWN_FAILED = "spawn_failed" # connector spawn failed
SPAWN_NOT_ALIVE = (
"spawn_not_alive" # spawn succeeded but process did not come alive
)
BLOCKED_BY_DELETION = "blocked_by_deletion"
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
FENCE_NOT_FOUND = "fence_not_found" # fence does not exist
FENCE_READINESS_TIMEOUT = (
"fence_readiness_timeout" # fence exists but wasn't ready within the timeout
)
FENCE_MISMATCH = "fence_mismatch" # task and fence metadata mismatch
TASK_ALREADY_RUNNING = "task_already_running" # task appears to be running already
INDEX_ATTEMPT_MISMATCH = (
"index_attempt_mismatch" # expected index attempt metadata not found in db
)
CONNECTOR_VALIDATION_ERROR = (
"connector_validation_error" # the connector validation failed
)
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
# the watchdog received a termination signal
TERMINATED_BY_SIGNAL = "terminated_by_signal"
# the watchdog terminated the task due to no activity
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
# NOTE: this may actually be the same as SIGKILL, but parsed differently by python
# consolidate once we know more
OUT_OF_MEMORY = "out_of_memory"
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
@property
def code(self) -> int:
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
}
return _ENUM_TO_CODE[self]
@classmethod
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
}
if code in _CODE_TO_ENUM:
return _CODE_TO_ENUM[code]
return IndexingWatchdogTerminalStatus.UNDEFINED
class SimpleJobResult:
"""The data we want to have when the watchdog finishes"""
def __init__(self) -> None:
self.status = IndexingWatchdogTerminalStatus.UNDEFINED
self.connector_source = None
self.exit_code = None
self.exception_str = None
status: IndexingWatchdogTerminalStatus
connector_source: str | None
exit_code: int | None
exception_str: str | None
class ConnectorIndexingContext(BaseModel):
tenant_id: str | None
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
class ConnectorIndexingLogBuilder:
def __init__(self, ctx: ConnectorIndexingContext):
self.ctx = ctx
def build(self, msg: str, **kwargs: Any) -> str:
msg_final = (
f"{msg}: "
f"tenant_id={self.ctx.tenant_id} "
f"attempt={self.ctx.index_attempt_id} "
f"cc_pair={self.ctx.cc_pair_id} "
f"search_settings={self.ctx.search_settings_id}"
)
# Append extra keyword arguments in logfmt style
if kwargs:
extra_logfmt = " ".join(f"{key}={value}" for key, value in kwargs.items())
msg_final = f"{msg_final} {extra_logfmt}"
return msg_final
def monitor_ccpair_indexing_taskset(
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
@@ -222,12 +361,13 @@ def monitor_ccpair_indexing_taskset(
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
tasks_created = 0
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
redis_client = get_redis_client()
redis_client_replica = get_redis_replica_client()
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
@@ -265,7 +405,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# 1/3: KICKOFF
# check for search settings swap
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
@@ -286,7 +426,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# gather cc_pair_ids
lock_beat.reacquire()
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
@@ -296,7 +436,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
search_settings_list = get_active_search_settings_list(db_session)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
@@ -374,7 +514,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
db_session, redis_client
)
@@ -426,7 +566,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
monitor_ccpair_indexing_taskset(
tenant_id, key_bytes, redis_client_replica, db_session
)
@@ -457,8 +597,8 @@ def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
tenant_id: str | None,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -496,7 +636,6 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
attempt_found = False
n_final_progress: int | None = None
# 20 is the documented default for httpx max_keepalive_connections
@@ -510,22 +649,24 @@ def connector_indexing_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
if redis_connector.delete.fenced:
raise RuntimeError(
raise SimpleJobException(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}"
f"fence={redis_connector.delete.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
)
if redis_connector.stop.fenced:
raise RuntimeError(
raise SimpleJobException(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}"
f"fence={redis_connector.stop.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
)
# this wait is needed to avoid a race condition where
@@ -534,19 +675,24 @@ def connector_indexing_task(
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
raise SimpleJobException(
f"connector_indexing_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.permissions.fence_key}"
f"fence={redis_connector.permissions.fence_key}",
code=IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT.code,
)
if not redis_connector_index.fenced: # The fence must exist
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
raise SimpleJobException(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}",
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
)
payload = redis_connector_index.payload # The payload must exist
if not payload:
raise ValueError("connector_indexing_task: payload invalid or not found")
raise SimpleJobException(
"connector_indexing_task: payload invalid or not found",
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
)
if payload.index_attempt_id is None or payload.celery_task_id is None:
logger.info(
@@ -556,10 +702,11 @@ def connector_indexing_task(
continue
if payload.index_attempt_id != index_attempt_id:
raise ValueError(
raise SimpleJobException(
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
f"task_index_attempt={index_attempt_id} "
f"payload_index_attempt={payload.index_attempt_id}"
f"payload_index_attempt={payload.index_attempt_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
)
logger.info(
@@ -583,19 +730,26 @@ def connector_indexing_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return None
raise SimpleJobException(
f"Indexing task already running, exiting...: "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING.code,
)
payload.started = datetime.now(timezone.utc)
redis_connector_index.set_fence(payload)
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise ValueError(
f"Index attempt not found: index_attempt={index_attempt_id}"
raise SimpleJobException(
f"Index attempt not found: index_attempt={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
attempt_found = True
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
@@ -603,25 +757,30 @@ def connector_indexing_task(
)
if not cc_pair:
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
raise SimpleJobException(
f"cc_pair not found: cc_pair={cc_pair_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
if not cc_pair.connector:
raise ValueError(
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
raise SimpleJobException(
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
if not cc_pair.credential:
raise ValueError(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
raise SimpleJobException(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
# define a callback class
callback = IndexingCallback(
os.getppid(),
redis_connector,
redis_connector_index,
lock,
r,
redis_connector_index,
)
logger.info(
@@ -643,6 +802,15 @@ def connector_indexing_task(
# get back the total number of indexed docs and return it
n_final_progress = redis_connector_index.get_progress()
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
except ConnectorValidationError:
raise SimpleJobException(
f"Indexing task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
@@ -650,22 +818,8 @@ def connector_indexing_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if attempt_found:
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(
index_attempt_id, db_session, failure_reason=str(e)
)
except Exception:
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise e
finally:
if lock.owned():
lock.release()
@@ -678,41 +832,49 @@ def connector_indexing_task(
return n_final_progress
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
def process_job_result(
job: SimpleJob,
connector_source: str | None,
redis_connector_index: RedisConnectorIndex,
log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
result = SimpleJobResult()
result.connector_source = connector_source
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except Exception:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
if job.process:
result.exit_code = job.process.exitcode
# There is a cloud related bug outside of our code
# where spawned tasks return with an exit code of 1.
# Unfortunately, exceptions also return with an exit code of 1,
# so just raising an exception isn't informative
# Exiting with 255 makes it possible to distinguish between normal exits
# and exceptions.
sys.exit(255)
if job.status != "error":
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
return result
ignore_exitcode = False
# In EKS, there is an edge case where successful tasks return exit
# code 1 in the cloud due to the set_spawn_method not sticking.
# We've since worked around this, but the following is a safe way to
# work around this issue. Basically, we ignore the job error state
# if the completion signal is OK.
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
task_logger.warning(
log_builder.build(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...",
exit_code=str(result.exit_code),
)
)
else:
if result.exit_code is not None:
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
result.exception_str = job.exception()
return result
@@ -730,12 +892,32 @@ def connector_indexing_proxy_task(
search_settings_id: int,
tenant_id: str | None,
) -> None:
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
To work around this, we use pool=threads and proxy our work to a spawned task.
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
"""
start = time.monotonic()
result = SimpleJobResult()
ctx = ConnectorIndexingContext(
tenant_id=tenant_id,
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
index_attempt_id=index_attempt_id,
)
log_builder = ConnectorIndexingLogBuilder(ctx)
task_logger.info(
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"mp_start_method={multiprocessing.get_start_method()}"
log_builder.build(
"Indexing watchdog - starting",
mp_start_method=str(multiprocessing.get_start_method()),
)
)
if not self.request.id:
@@ -744,149 +926,297 @@ def connector_indexing_proxy_task(
client = SimpleJobClient()
job = client.submit(
connector_indexing_task_wrapper,
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
global_version.is_ee_version(),
pure=False,
tenant_id,
)
if not job:
if not job or not job.process:
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
task_logger.info(
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
return
# Ensure the process has moved out of the starting state
num_waits = 0
while True:
if num_waits > 15:
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
job.release()
return
if job.process.is_alive() or job.process.exitcode is not None:
break
sleep(1)
num_waits += 1
task_logger.info(
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
log_builder.build(
"Indexing watchdog - spawn succeeded",
pid=str(job.process.pid),
)
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
while True:
sleep(5)
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
raise RuntimeError("Index attempt not found")
# renew watchdog signal (this has a shorter timeout than set_active)
redis_connector_index.set_watchdog(True)
# renew active signal
redis_connector_index.set_active()
# if the job is done, clean up and break
if job.done():
exit_code: int | None
try:
if job.status == "error":
ignore_exitcode = False
exit_code = None
if job.process:
exit_code = job.process.exitcode
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# even though logging clearly indicates successful completion
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if not ignore_exitcode:
raise RuntimeError("Spawned task exceptioned.")
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
except Exception:
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
)
raise
finally:
job.release()
break
# if a termination signal is detected, clean up and break
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
result.connector_source = (
index_attempt.connector_credential_pair.connector.source.value
)
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
redis_connector_index.set_active() # renew active signal
redis_connector_index.set_connector_active() # prime the connective active signal
while True:
sleep(5)
# renew watchdog signal (this has a shorter timeout than set_active)
redis_connector_index.set_watchdog(True)
# renew active signal
redis_connector_index.set_active()
# if the job is done, clean up and break
if job.done():
try:
result = process_job_result(
job, result.connector_source, redis_connector_index, log_builder
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
"Indexing watchdog - transient exception marking index attempt as canceled: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - spawned task exceptioned"
)
)
finally:
job.release()
break
# if a termination signal is detected, clean up and break
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
log_builder.build("Indexing watchdog - termination signal detected")
)
job.cancel()
break
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
break
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
if not redis_connector_index.connector_active():
task_logger.warning(
log_builder.build(
"Indexing watchdog - activity timeout exceeded",
timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
)
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
result.status = (
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
)
break
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception as e:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
if isinstance(e, ConnectorValidationError):
# No need to expose full stack trace for validation errors
result.exception_str = str(e)
else:
result.exception_str = traceback.format_exc()
# handle exit and reporting
elapsed = time.monotonic() - start
if result.exception_str is not None:
# print with exception
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
with get_session_with_current_tenant() as db_session:
failure_reason = (
f"Spawned task exceptioned: exit_code={result.exit_code}"
)
mark_attempt_failed(
ctx.index_attempt_id,
db_session,
failure_reason=failure_reason,
full_exception_trace=result.exception_str,
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
continue
normalized_exception_str = "None"
if result.exception_str:
normalized_exception_str = result.exception_str.replace(
"\n", "\\n"
).replace('"', '\\"')
task_logger.warning(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=result.status.value,
exit_code=str(result.exit_code),
exception=f'"{normalized_exception_str}"',
elapsed=f"{elapsed:.2f}s",
)
)
redis_connector_index.set_watchdog(False)
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
# print without exception
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
)
)
job.cancel()
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=str(result.status.value),
exit_code=str(result.exit_code),
elapsed=f"{elapsed:.2f}s",
)
)
redis_connector_index.set_watchdog(False)
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
)
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
"""Clean up old checkpoints that are older than 7 days."""
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# these tasks should never overlap
if not lock.acquire(blocking=False):
return None
try:
locked = True
with get_session_with_current_tenant() as db_session:
old_attempts = get_index_attempts_with_old_checkpoints(db_session)
for attempt in old_attempts:
task_logger.info(
f"Cleaning up checkpoint for index attempt {attempt.id}"
)
cleanup_checkpoint_task.apply_async(
kwargs={
"index_attempt_id": attempt.id,
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.CHECKPOINT_CLEANUP,
)
except Exception:
task_logger.exception("Unexpected exception during checkpoint cleanup")
return None
finally:
if locked:
if lock.owned():
lock.release()
else:
task_logger.error(
"check_for_checkpoint_cleanup - Lock not owned on completion: "
f"tenant={tenant_id}"
)
@shared_task(
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
bind=True,
)
def cleanup_checkpoint_task(
self: Task, *, index_attempt_id: int, tenant_id: str | None
) -> None:
"""Clean up a checkpoint for a given index attempt"""
with get_session_with_current_tenant() as db_session:
cleanup_checkpoint(db_session, index_attempt_id)

View File

@@ -23,7 +23,7 @@ from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
@@ -93,27 +93,25 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
return unfenced_attempts
class IndexingCallback(IndexingHeartbeatInterface):
class IndexingCallbackBase(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_connector_index: RedisConnectorIndex,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_connector: RedisConnector = redis_connector
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
self.redis_lock: RedisLock = redis_lock
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "IndexingCallback.__init__"
self.last_tag: str = f"{self.__class__.__name__}.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
@@ -127,8 +125,8 @@ class IndexingCallback(IndexingHeartbeatInterface):
def progress(self, tag: str, amount: int) -> None:
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
# so leave this code in until we're ready to test it.
# with daemon=True. It seems likely some indexing tasks will need to spawn other processes
# eventually, which daemon=True prevents, so leave this code in until we're ready to test it.
# if self.parent_pid:
# # check if the parent pid is alive so we aren't running as a zombie
@@ -143,8 +141,6 @@ class IndexingCallback(IndexingHeartbeatInterface):
# self.last_parent_check = now
try:
self.redis_connector.prune.set_active()
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
@@ -156,7 +152,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
self.last_tag = tag
except LockError:
logger.exception(
f"IndexingCallback - lock.reacquire exceptioned: "
f"{self.__class__.__name__} - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
@@ -167,6 +163,24 @@ class IndexingCallback(IndexingHeartbeatInterface):
redis_lock_dump(self.redis_lock, self.redis_client)
raise
class IndexingCallback(IndexingCallbackBase):
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
redis_connector_index: RedisConnectorIndex,
):
super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
def progress(self, tag: str, amount: int) -> None:
self.redis_connector_index.set_active()
self.redis_connector_index.set_connector_active()
super().progress(tag, amount)
self.redis_client.incrby(
self.redis_connector_index.generator_progress_key, amount
)
@@ -240,7 +254,8 @@ def validate_indexing_fence(
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
f"validate_indexing_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
@@ -317,7 +332,7 @@ def validate_indexing_fences(
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
continue
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,

View File

@@ -8,7 +8,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import LLMProvider
@@ -75,7 +75,7 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | N
return None
# Then update the database with the fetched models
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# Get the default LLM provider
default_provider = (
db_session.query(LLMProvider)

View File

@@ -26,7 +26,8 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -42,7 +43,6 @@ from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
@@ -668,7 +668,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
task_logger.info("Starting background monitoring")
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
lock_monitoring: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
@@ -683,7 +683,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client(tenant_id=tenant_id)
redis_std = get_redis_client()
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
@@ -693,7 +693,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
]
# Collect and log each metric
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
@@ -771,12 +771,11 @@ def cloud_check_alembic() -> bool | None:
if tenant_id is None:
continue
with get_session_with_tenant(tenant_id=None) as session:
with get_session_with_shared_schema() as session:
try:
result = session.execute(
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
)
result_scalar: str | None = result.scalar_one_or_none()
if result_scalar is None:
raise ValueError("Alembic version should not be None.")

View File

@@ -15,7 +15,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PostgresAdvisoryLocks
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
@shared_task(
@@ -36,7 +36,7 @@ def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),

View File

@@ -21,7 +21,7 @@ from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -41,7 +41,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_documents_for_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -62,6 +62,12 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class PruneCallback(IndexingCallbackBase):
def progress(self, tag: str, amount: int) -> None:
self.redis_connector.prune.set_active()
super().progress(tag, amount)
"""Jobs / utils for kicking off pruning tasks."""
@@ -108,8 +114,8 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -127,14 +133,14 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
# but pruning only kicks off once per hour
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
cc_pair_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -182,7 +188,7 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
task_logger.info(
@@ -337,7 +343,7 @@ def connector_pruning_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client(tenant_id=tenant_id)
r = get_redis_client()
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -395,7 +401,7 @@ def connector_pruning_generator_task(
return None
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
@@ -425,6 +431,7 @@ def connector_pruning_generator_task(
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source}"
)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@@ -434,12 +441,11 @@ def connector_pruning_generator_task(
)
search_settings = get_current_search_settings(db_session)
redis_connector_index = redis_connector.new_index(search_settings.id)
redis_connector.new_index(search_settings.id)
callback = IndexingCallback(
callback = PruneCallback(
0,
redis_connector,
redis_connector_index,
lock,
r,
)

View File

@@ -27,7 +27,7 @@ from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
@@ -79,7 +79,7 @@ def document_by_cc_pair_cleanup_task(
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
action = "skip"
chunks_affected = 0
@@ -105,6 +105,7 @@ def document_by_cc_pair_cleanup_task(
tenant_id=tenant_id,
chunk_count=chunk_count,
)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
@@ -204,7 +205,7 @@ def document_by_cc_pair_cleanup_task(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(

View File

@@ -34,7 +34,7 @@ from onyx.db.document_set import fetch_document_sets
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import mark_document_set_as_synced
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import DocumentSet
@@ -84,8 +84,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
time_start = time.monotonic()
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r = get_redis_client()
r_replica = get_redis_replica_client()
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -98,7 +98,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
try:
# 1/3: KICKOFF
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
try_generate_stale_document_sync_tasks(
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
)
@@ -106,7 +106,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
# region document set scan
lock_beat.reacquire()
document_set_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
@@ -117,7 +117,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
for document_set_id in document_set_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
try_generate_document_set_sync_tasks(
self.app, document_set_id, db_session, r, lock_beat, tenant_id
)
@@ -136,7 +136,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
pass
else:
usergroup_ids: list[int] = []
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
@@ -146,7 +146,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
for usergroup_id in usergroup_ids:
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
try_generate_user_group_sync_tasks(
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
)
@@ -167,7 +167,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
monitor_usergroup_taskset = (
@@ -177,7 +177,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
noop_fallback,
)
)
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
@@ -523,12 +523,12 @@ def monitor_document_set_taskset(
max_retries=3,
)
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
self: Task, document_id: str, *, tenant_id: str | None
) -> bool:
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,

View File

@@ -1,5 +1,5 @@
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
def emit_background_error(
@@ -9,5 +9,5 @@ def emit_background_error(
"""Currently just saves a row in the background_errors table.
In the future, could create notifications based on the severity."""
with get_session_with_tenant() as db_session:
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)

View File

@@ -1,80 +0,0 @@
"""Experimental functionality related to splitting up indexing
into a series of checkpoints to better handle intermittent failures
/ jobs being killed by cloud providers."""
import datetime
from onyx.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
def _2010_dt() -> datetime.datetime:
return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)
def _2020_dt() -> datetime.datetime:
return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)
def _default_end_time(
last_successful_run: datetime.datetime | None,
) -> datetime.datetime:
"""If year is before 2010, go to the beginning of 2010.
If year is 2010-2020, go in 5 year increments.
If year > 2020, then go in 180 day increments.
For connectors that don't support a `filter_by` and instead rely on `sort_by`
for polling, then this will cause a massive duplication of fetches. For these
connectors, you may want to override this function to return a more reasonable
plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
last_successful_run = (
datetime_to_utc(last_successful_run) if last_successful_run else None
)
if last_successful_run is None or last_successful_run < _2010_dt():
return _2010_dt()
if last_successful_run < _2020_dt():
return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt())
return last_successful_run + datetime.timedelta(days=180)
def find_end_time_for_indexing_attempt(
last_successful_run: datetime.datetime | None,
# source_type can be used to override the default for certain connectors, currently unused
source_type: DocumentSource,
) -> datetime.datetime | None:
"""Is the current time unless the connector is run over a large period, in which case it is
split up into large time segments that become smaller as it approaches the present
"""
# NOTE: source_type can be used to override the default for certain connectors
end_of_window = _default_end_time(last_successful_run)
now = datetime.datetime.now(tz=datetime.timezone.utc)
if end_of_window < now:
return end_of_window
# None signals that we should index up to current time
return None
def get_time_windows_for_index_attempt(
last_successful_run: datetime.datetime, source_type: DocumentSource
) -> list[tuple[datetime.datetime, datetime.datetime]]:
if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]
time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
start_of_window: datetime.datetime | None = last_successful_run
while start_of_window:
end_of_window = find_end_time_for_indexing_attempt(
last_successful_run=start_of_window, source_type=source_type
)
time_windows.append(
(
start_of_window,
end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
)
)
start_of_window = end_of_window
return time_windows

View File

@@ -0,0 +1,200 @@
from datetime import datetime
from datetime import timedelta
from io import BytesIO
from sqlalchemy import and_
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import ConnectorCheckpoint
from onyx.db.engine import get_db_current_time
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from onyx.utils.object_size_check import deep_getsizeof
logger = setup_logger()
_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT = 100
def _build_checkpoint_pointer(index_attempt_id: int) -> str:
return f"checkpoint_{index_attempt_id}.json"
def save_checkpoint(
db_session: Session, index_attempt_id: int, checkpoint: ConnectorCheckpoint
) -> str:
"""Save a checkpoint for a given index attempt to the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
file_store.save_file(
file_name=checkpoint_pointer,
content=BytesIO(checkpoint.model_dump_json().encode()),
display_name=checkpoint_pointer,
file_origin=FileOrigin.INDEXING_CHECKPOINT,
file_type="application/json",
)
index_attempt = get_index_attempt(db_session, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
index_attempt.checkpoint_pointer = checkpoint_pointer
db_session.add(index_attempt)
db_session.commit()
return checkpoint_pointer
def load_checkpoint(
db_session: Session, index_attempt_id: int
) -> ConnectorCheckpoint | None:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
try:
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
except RuntimeError:
return None
def get_latest_valid_checkpoint(
db_session: Session,
cc_pair_id: int,
search_settings_id: int,
window_start: datetime,
window_end: datetime,
) -> ConnectorCheckpoint:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
db_session=db_session,
limit=_NUM_RECENT_ATTEMPTS_TO_CONSIDER,
)
checkpoint_candidates = [
candidate
for candidate in checkpoint_candidates
if (
candidate.poll_range_start == window_start
and candidate.poll_range_end == window_end
and candidate.status == IndexingStatus.FAILED
and candidate.checkpoint_pointer is not None
# we want to make sure that the checkpoint is actually useful
# if it's only gone through a few docs, it's probably not worth
# using. This also avoids weird cases where a connector is basically
# non-functional but still "makes progress" by slowly moving the
# checkpoint forward run after run
and candidate.total_docs_indexed
and candidate.total_docs_indexed > _NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT
)
]
# don't keep using checkpoints if we've had a bunch of failed attempts in a row
# for now, capped at 10
if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
logger.warning(
f"{_NUM_RECENT_ATTEMPTS_TO_CONSIDER} consecutive failed attempts found "
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return ConnectorCheckpoint.build_dummy_checkpoint()
# assumes latest checkpoint is the furthest along. This only isn't true
# if something else has gone wrong.
latest_valid_checkpoint_candidate = (
checkpoint_candidates[0] if checkpoint_candidates else None
)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
if latest_valid_checkpoint_candidate:
try:
previous_checkpoint = load_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
)
except Exception:
logger.exception(
f"Failed to load checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}."
)
previous_checkpoint = None
if previous_checkpoint is not None:
logger.info(
f"Using checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
f"{previous_checkpoint}"
)
save_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
checkpoint=previous_checkpoint,
)
checkpoint = previous_checkpoint
return checkpoint
def get_index_attempts_with_old_checkpoints(
db_session: Session, days_to_keep: int = 7
) -> list[IndexAttempt]:
"""Get all index attempts with checkpoints older than the specified number of days.
Args:
db_session: The database session
days_to_keep: Number of days to keep checkpoints for (default: 7)
Returns:
Number of checkpoints deleted
"""
cutoff_date = get_db_current_time(db_session) - timedelta(days=days_to_keep)
# Find all index attempts with checkpoints older than cutoff_date
old_attempts = (
db_session.query(IndexAttempt)
.filter(
and_(
IndexAttempt.checkpoint_pointer.isnot(None),
IndexAttempt.time_created < cutoff_date,
)
)
.all()
)
return old_attempts
def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
"""Clean up a checkpoint for a given index attempt"""
index_attempt = get_index_attempt(db_session, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if not index_attempt.checkpoint_pointer:
return None
file_store = get_default_file_store(db_session)
file_store.delete_file(index_attempt.checkpoint_pointer)
index_attempt.checkpoint_pointer = None
db_session.add(index_attempt)
db_session.commit()
return None
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
"""Check if the checkpoint content size exceeds the limit (200MB)"""
content_size = deep_getsizeof(checkpoint.checkpoint_content)
if content_size > 200_000_000: # 200MB in bytes
raise ValueError(
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"
)

View File

@@ -5,6 +5,8 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing as mp
import sys
import traceback
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
@@ -18,6 +20,16 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class SimpleJobException(Exception):
"""lets us raise an exception that will return a specific error code"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
code: int | None = kwargs.pop("code", None)
self.code = code
super().__init__(*args, **kwargs)
JobStatusType = (
Literal["error"]
| Literal["finished"]
@@ -28,7 +40,10 @@ JobStatusType = (
def _initializer(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
func: Callable,
queue: mp.Queue,
args: list | tuple,
kwargs: dict[str, Any] | None = None,
) -> Any:
"""Initialize the child process with a fresh SQLAlchemy Engine.
@@ -52,13 +67,29 @@ def _initializer(
)
# Proceed with executing the target function
return func(*args, **kwargs)
try:
return func(*args, **kwargs)
except SimpleJobException as e:
logger.exception("SimpleJob raised a SimpleJobException")
error_msg = traceback.format_exc()
queue.put(error_msg) # Send the exception to the parent process
sys.exit(e.code) # use the given exit code
except Exception:
logger.exception("SimpleJob raised an exception")
error_msg = traceback.format_exc()
queue.put(error_msg) # Send the exception to the parent process
sys.exit(255) # use 255 to indicate a generic exception
def _run_in_process(
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
func: Callable,
queue: mp.Queue,
args: list | tuple,
kwargs: dict[str, Any] | None = None,
) -> None:
_initializer(func, args, kwargs)
_initializer(func, queue, args, kwargs)
@dataclass
@@ -67,6 +98,8 @@ class SimpleJob:
id: int
process: Optional["SpawnProcess"] = None
queue: Optional[mp.Queue] = None
_exception: Optional[str] = None
def cancel(self) -> bool:
return self.release()
@@ -100,9 +133,15 @@ class SimpleJob:
def exception(self) -> str:
"""Needed to match the Dask API, but not implemented since we don't currently
have a way to get back the exception information from the child process."""
return (
f"Job with ID '{self.id}' was killed or encountered an unhandled exception."
)
"""Retrieve exception from the multiprocessing queue if available."""
if self._exception is None and self.queue and not self.queue.empty():
self._exception = self.queue.get() # Get exception from queue
if self._exception:
return self._exception
return f"Job with ID '{self.id}' did not report an exception."
class SimpleJobClient:
@@ -137,8 +176,11 @@ class SimpleJobClient:
# this approach allows us to always "spawn" a new process regardless of
# get_start_method's current setting
ctx = mp.get_context("spawn")
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
queue = ctx.Queue()
process = ctx.Process(
target=_run_in_process, args=(func, queue, args), daemon=True
)
job = SimpleJob(id=job_id, process=process, queue=queue)
process.start()
self.jobs[job_id] = job

View File

@@ -0,0 +1,87 @@
import tracemalloc
from onyx.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class MemoryTracer:
def __init__(self, interval: int = 0, num_print_entries: int = 5):
self.interval = interval
self.num_print_entries = num_print_entries
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
self.counter = 0
def start(self) -> None:
"""Start the memory tracer if interval is greater than 0."""
if self.interval > 0:
logger.debug(f"Memory tracer starting: interval={self.interval}")
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
self._take_snapshot()
def stop(self) -> None:
"""Stop the memory tracer if it's running."""
if self.interval > 0:
self.log_final_diff()
tracemalloc.stop()
logger.debug("Memory tracer stopped.")
def _take_snapshot(self) -> None:
"""Take a snapshot and update internal snapshot states."""
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__),
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
tracemalloc.Filter(False, "<frozen importlib._bootstrap_external>"),
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def _log_diff(
self, current: tracemalloc.Snapshot, previous: tracemalloc.Snapshot
) -> None:
"""Log the memory difference between two snapshots."""
stats = current.compare_to(previous, "traceback")
for s in stats[: self.num_print_entries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def increment_and_maybe_trace(self) -> None:
"""Increment counter and perform trace if interval is hit."""
if self.interval <= 0:
return
self.counter += 1
if self.counter % self.interval == 0:
logger.debug(
f"Running trace comparison for batch {self.counter}. interval={self.interval}"
)
self._take_snapshot()
if self.snapshot and self.snapshot_prev:
self._log_diff(self.snapshot, self.snapshot_prev)
def log_final_diff(self) -> None:
"""Log the final memory diff between start and end of indexing."""
if self.interval <= 0:
return
logger.debug(
f"Running trace comparison between start and end of indexing. {self.counter} batches processed."
)
self._take_snapshot()
if self.snapshot and self.snapshot_first:
self._log_diff(self.snapshot, self.snapshot_first)

View File

@@ -0,0 +1,40 @@
from datetime import datetime
from pydantic import BaseModel
from onyx.db.models import IndexAttemptError
class IndexAttemptErrorPydantic(BaseModel):
id: int
connector_credential_pair_id: int
document_id: str | None
document_link: str | None
entity_id: str | None
failed_time_range_start: datetime | None
failed_time_range_end: datetime | None
failure_message: str
is_resolved: bool = False
time_created: datetime
index_attempt_id: int
@classmethod
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
return cls(
id=model.id,
connector_credential_pair_id=model.connector_credential_pair_id,
document_id=model.document_id,
document_link=model.document_link,
entity_id=model.entity_id,
failed_time_range_start=model.failed_time_range_start,
failed_time_range_end=model.failed_time_range_end,
failure_message=model.failure_message,
is_resolved=model.is_resolved,
time_created=model.time_created,
index_attempt_id=model.index_attempt_id,
)

View File

@@ -1,5 +1,6 @@
import time
import traceback
from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -7,8 +8,11 @@ from datetime import timezone
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
from onyx.background.indexing.tracer import OnyxTracer
from onyx.background.indexing.checkpointing_utils import check_checkpoint_size
from onyx.background.indexing.checkpointing_utils import get_latest_valid_checkpoint
from onyx.background.indexing.checkpointing_utils import save_checkpoint
from onyx.background.indexing.memory_tracer import MemoryTracer
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
@@ -17,22 +21,28 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.document_index.factory import get_default_document_index
@@ -53,6 +63,7 @@ INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
batch_size: int,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
@@ -76,6 +87,11 @@ def _get_connector_runner(
credential=attempt.connector_credential_pair.credential,
tenant_id=tenant_id,
)
# validate the connector settings
runnable_connector.validate_connector_settings()
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
@@ -100,7 +116,9 @@ def _get_connector_runner(
raise e
return ConnectorRunner(
connector=runnable_connector, time_range=(start_time, end_time)
connector=runnable_connector,
batch_size=batch_size,
time_range=(start_time, end_time),
)
@@ -159,6 +177,66 @@ class RunIndexingContext(BaseModel):
search_settings_status: IndexModelStatus
def _check_connector_and_attempt_status(
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
) -> None:
"""
Checks the status of the connector credential pair and index attempt.
Raises a RuntimeError if any conditions are not met.
"""
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
if (
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_loop:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
def _check_failure_threshold(
total_failures: int,
document_count: int,
batch_num: int,
last_failure: ConnectorFailure | None,
) -> None:
"""Check if we've hit the failure threshold and raise an appropriate exception if so.
We consider the threshold hit if:
1. We have more than 3 failures AND
2. Failures account for more than 10% of processed documents
"""
failure_ratio = total_failures / (document_count or 1)
FAILURE_THRESHOLD = 3
FAILURE_RATIO_THRESHOLD = 0.1
if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD:
logger.error(
f"Connector run failed with '{total_failures}' errors "
f"after '{batch_num}' batches."
)
if last_failure and last_failure.exception:
raise last_failure.exception from last_failure.exception
raise RuntimeError(
f"Connector run encountered too many errors, aborting. "
f"Last error: {last_failure}"
)
def _run_indexing(
db_session: Session,
index_attempt_id: int,
@@ -169,13 +247,10 @@ def _run_indexing(
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
TODO: do not change index attempt statuses here ... instead, set signals in redis
and allow the monitor function to clean them up
"""
start_time = time.time()
start_time = time.monotonic() # jsut used for logging
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_start:
raise ValueError(
@@ -221,6 +296,46 @@ def _run_indexing(
db_session=db_session_temp,
)
)
if last_successful_index_time > POLL_CONNECTOR_OFFSET:
window_start = datetime.fromtimestamp(
last_successful_index_time, tz=timezone.utc
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
else:
# don't go into "negative" time if we've never indexed before
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
most_recent_attempt = next(
iter(
get_recent_completed_attempts_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt_start.search_settings_id,
db_session=db_session_temp,
limit=1,
)
),
None,
)
# if the last attempt failed, try and use the same window. This is necessary
# to ensure correctness with checkpointing. If we don't do this, things like
# new slack channels could be missed (since existing slack channels are
# cached as part of the checkpoint).
if (
most_recent_attempt
and most_recent_attempt.poll_range_end
and (
most_recent_attempt.status == IndexingStatus.FAILED
or most_recent_attempt.status == IndexingStatus.CANCELED
)
):
window_end = most_recent_attempt.poll_range_end
else:
window_end = datetime.now(tz=timezone.utc)
# add start/end now that they have been set
index_attempt_start.poll_range_start = window_start
index_attempt_start.poll_range_end = window_end
db_session_temp.add(index_attempt_start)
db_session_temp.commit()
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=index_attempt_start.search_settings,
@@ -234,7 +349,6 @@ def _run_indexing(
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt_id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=(
@@ -246,63 +360,73 @@ def _run_indexing(
callback=callback,
)
tracer: OnyxTracer
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = OnyxTracer()
tracer.start()
tracer.snap()
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
memory_tracer.start()
index_attempt_md = IndexAttemptMetadata(
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
total_failures = 0
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
run_end_dt = None
tracer_counter: int
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
for ind, (window_start, window_end) in enumerate(
get_time_windows_for_index_attempt(
last_successful_run=datetime.fromtimestamp(
last_successful_index_time, tz=timezone.utc
),
source_type=db_connector.source,
)
):
cc_pair_loop: ConnectorCredentialPair | None = None
index_attempt_loop: IndexAttempt | None = None
tracer_counter = 0
try:
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
connector_runner = _get_connector_runner(
db_session=db_session_temp,
attempt=index_attempt,
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_loop_start = get_index_attempt(
db_session_temp, index_attempt_id
)
if not index_attempt_loop_start:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
connector_runner = _get_connector_runner(
# don't use a checkpoint if we're explicitly indexing from
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling.
if index_attempt.from_beginning:
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
else:
checkpoint = get_latest_valid_checkpoint(
db_session=db_session_temp,
attempt=index_attempt_loop_start,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():
unresolved_errors = get_index_attempt_errors_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
unresolved_only=True,
db_session=db_session_temp,
)
doc_id_to_unresolved_errors: dict[
str, list[IndexAttemptError]
] = defaultdict(list)
for error in unresolved_errors:
if error.document_id:
doc_id_to_unresolved_errors[error.document_id].append(error)
entity_based_unresolved_errors = [
error for error in unresolved_errors if error.entity_id
]
while checkpoint.has_more:
logger.info(
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
@@ -312,42 +436,38 @@ def _run_indexing(
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
with get_session_with_tenant(tenant_id) as db_session_temp:
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
with get_session_with_current_tenant() as db_session_temp:
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp, ctx, index_attempt_id
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
if (
(
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_current_tenant() as db_session_temp:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
# if it's deleting, we don't care if this is a secondary index
or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(
db_session_temp, index_attempt_id
_check_failure_threshold(
total_failures, document_count, batch_num, failure
)
if not index_attempt_loop:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
# save the new checkpoint (if one is provided)
if next_checkpoint:
checkpoint = next_checkpoint
# below is all document processing logic, so if no batch we can just continue
if document_batch is None:
continue
batch_description = []
doc_batch_cleaned = strip_null_characters(doc_batch)
doc_batch_cleaned = strip_null_characters(document_batch)
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
@@ -377,15 +497,51 @@ def _run_indexing(
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
# of the transactions when computing `NOW()`, so if we have
# a long running transaction, the `time_updated` field will
# be inaccurate
db_session.commit()
# resolve errors for documents that were successfully indexed
failed_document_ids = [
failure.failed_document.document_id
for failure in index_pipeline_result.failures
if failure.failed_document
]
successful_document_ids = [
document.id
for document in document_batch
if document.id not in failed_document_ids
]
for document_id in successful_document_ids:
with get_session_with_current_tenant() as db_session_temp:
if document_id in doc_id_to_unresolved_errors:
logger.info(
f"Resolving IndexAttemptError for document '{document_id}'"
)
for error in doc_id_to_unresolved_errors[document_id]:
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
# add brand new failures
if index_pipeline_result.failures:
total_failures += len(index_pipeline_result.failures)
with get_session_with_current_tenant() as db_session_temp:
for failure in index_pipeline_result.failures:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
_check_failure_threshold(
total_failures,
document_count,
batch_num,
index_pipeline_result.failures[-1],
)
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_tenant(tenant_id) as db_session_temp:
with get_session_with_current_tenant() as db_session_temp:
# NOTE: Postgres uses the start of the transactions when computing `NOW()`
# so we need either to commit() or to use a new session
update_docs_indexed(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
@@ -397,126 +553,97 @@ def _run_indexing(
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.debug(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
memory_tracer.increment_and_maybe_trace()
run_end_dt = window_end
if ctx.is_primary:
with get_session_with_tenant(tenant_id) as db_session_temp:
# `make sure the checkpoints aren't getting too large`at some regular interval
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
check_checkpoint_size(checkpoint)
# save latest checkpoint
with get_session_with_current_tenant() as db_session_temp:
save_checkpoint(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "
f"{time.monotonic() - start_time} seconds"
)
if isinstance(e, ConnectorValidationError):
# On validation errors during indexing, we want to cancel the indexing attempt
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
status=ConnectorCredentialPairStatus.INVALID,
)
memory_tracer.stop()
raise e
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
except Exception as e:
logger.exception(
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
if isinstance(e, ConnectorStopSignal):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or (
cc_pair_loop is not None and not cc_pair_loop.status.is_active()
)
or (
index_attempt_loop is not None
and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
)
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.debug("Memory tracer stopped.")
if (
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason="All batches exceptioned.",
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
memory_tracer.stop()
raise e
else:
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
elapsed_time = time.time() - start_time
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
with get_session_with_tenant(tenant_id) as db_session_temp:
if index_attempt_md.num_exceptions == 0:
memory_tracer.stop()
raise e
memory_tracer.stop()
elapsed_time = time.monotonic() - start_time
with get_session_with_current_tenant() as db_session_temp:
# resolve entity-based errors
for error in entity_based_unresolved_errors:
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
if total_failures == 0:
mark_attempt_succeeded(index_attempt_id, db_session_temp)
create_milestone_and_report(
@@ -535,7 +662,7 @@ def _run_indexing(
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
logger.info(
f"Connector completed with some errors: "
f"exceptions={index_attempt_md.num_exceptions} "
f"failures={total_failures} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
@@ -547,7 +674,7 @@ def _run_indexing(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
run_dt=run_end_dt,
run_dt=window_end,
)
@@ -558,46 +685,43 @@ def run_indexing_entrypoint(
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
try:
if is_ee:
global_version.set_ee()
"""Don't swallow exceptions here ... propagate them up."""
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_current_tenant() as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
connector_name = attempt.connector_credential_pair.connector.name
connector_config = (
attempt.connector_credential_pair.connector.connector_specific_config
)
with get_session_with_tenant(tenant_id) as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
credential_id = attempt.connector_credential_pair.credential_id
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
connector_name = attempt.connector_credential_pair.connector.name
connector_config = (
attempt.connector_credential_pair.connector.connector_specific_config
)
credential_id = attempt.connector_credential_pair.credential_id
with get_session_with_current_tenant() as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
with get_session_with_tenant(tenant_id) as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
except Exception as e:
logger.exception(
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)

View File

@@ -1,77 +0,0 @@
import tracemalloc
from onyx.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class OnyxTracer:
def __init__(self) -> None:
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
def start(self) -> None:
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
def stop(self) -> None:
tracemalloc.stop()
def snap(self) -> None:
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
tracemalloc.Filter(
False, "<frozen importlib._bootstrap>"
), # Exclude importlib
tracemalloc.Filter(
False, "<frozen importlib._bootstrap_external>"
), # Exclude external importlib
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def log_snapshot(self, numEntries: int) -> None:
if not self.snapshot:
return
stats = self.snapshot.statistics("traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer snap: {s}")
for line in s.traceback:
logger.debug(f"* {line}")
@staticmethod
def log_diff(
snap_current: tracemalloc.Snapshot,
snap_previous: tracemalloc.Snapshot,
numEntries: int,
) -> None:
stats = snap_current.compare_to(snap_previous, "traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def log_previous_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_prev:
return
OnyxTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
def log_first_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_first:
return
OnyxTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)

View File

@@ -27,8 +27,10 @@ from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -80,6 +82,26 @@ class Answer:
and not skip_explicit_tool_calling
)
rerank_settings = search_request.rerank_settings
using_cloud_reranking = (
rerank_settings is not None
and rerank_settings.rerank_provider_type is not None
)
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
# TODO: this is a hack to force the query to be used for the search tool
# this should be removed once we fully unify graph inputs (i.e.
# remove SearchQuery entirely)
if (
force_use_tool.force_use
and search_tool
and force_use_tool.args
and force_use_tool.tool_name == search_tool.name
and QUERY_FIELD in force_use_tool.args
):
search_request.query = force_use_tool.args[QUERY_FIELD]
self.graph_inputs = GraphInputs(
search_request=search_request,
prompt_builder=prompt_builder,
@@ -94,7 +116,6 @@ class Answer:
force_use_tool=force_use_tool,
using_tool_calling_llm=using_tool_calling_llm,
)
assert db_session, "db_session must be provided for agentic persistence"
self.graph_persistence = GraphPersistence(
db_session=db_session,
chat_session_id=chat_session_id,
@@ -104,6 +125,7 @@ class Answer:
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
allow_refinement=True,
allow_agent_reranking=allow_agent_reranking,
)
self.graph_config = GraphConfig(
inputs=self.graph_inputs,

View File

@@ -190,7 +190,8 @@ def create_chat_chain(
and previous_message.message_type == MessageType.ASSISTANT
and mainline_messages
):
mainline_messages[-1] = current_message
if current_message.refined_answer_improvement:
mainline_messages[-1] = current_message
else:
mainline_messages.append(current_message)

View File

@@ -142,6 +142,15 @@ class MessageResponseIDInfo(BaseModel):
reserved_assistant_message_id: int
class AgentMessageIDInfo(BaseModel):
level: int
message_id: int
class AgenticMessageResponseIDInfo(BaseModel):
agentic_message_ids: list[AgentMessageIDInfo]
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None

View File

@@ -7,10 +7,12 @@ from typing import cast
from sqlalchemy.orm import Session
from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
from onyx.chat.models import AgenticMessageResponseIDInfo
from onyx.chat.models import AgentMessageIDInfo
from onyx.chat.models import AgentSearchPacket
from onyx.chat.models import AllCitations
from onyx.chat.models import AnswerPostInfo
@@ -143,9 +145,10 @@ from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
def _translate_citations(
@@ -307,6 +310,7 @@ ChatPacket = (
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| AgenticMessageResponseIDInfo
| StreamStopInfo
| AgentSearchPacket
)
@@ -342,7 +346,7 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
tenant_id = get_current_tenant_id()
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
@@ -631,6 +635,7 @@ def stream_chat_message_objects(
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
)
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
@@ -1015,7 +1020,7 @@ def stream_chat_message_objects(
if info.message_specific_citations
else None
),
error=None,
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
@@ -1033,6 +1038,7 @@ def stream_chat_message_objects(
next_level = 1
prev_message = gen_ai_response_message
agent_answers = answer.llm_answer_by_level()
agentic_message_ids = []
while next_level in agent_answers:
next_answer = agent_answers[next_level]
info = info_by_subq[
@@ -1053,7 +1059,12 @@ def stream_chat_message_objects(
citations=info.message_specific_citations.citation_map
if info.message_specific_citations
else None,
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
refined_answer_improvement=refined_answer_improvement,
is_agentic=True,
)
agentic_message_ids.append(
AgentMessageIDInfo(level=next_level, message_id=next_answer_message.id)
)
next_level += 1
prev_message = next_answer_message
@@ -1061,11 +1072,9 @@ def stream_chat_message_objects(
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)
yield msg_detail_response
yield translate_db_message_to_chat_message_detail(gen_ai_response_message)
except Exception as e:
error_msg = str(e)
logger.exception(error_msg)

View File

@@ -31,22 +31,9 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 1 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8 # in seconds
AGENT_ANSWER_GENERATION_BY_FAST_LLM = (
os.environ.get("AGENT_ANSWER_GENERATION_BY_FAST_LLM", "").lower() == "true"
)
AGENT_RETRIEVAL_STATS = (
not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
@@ -178,80 +165,172 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
) # 2000
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 10 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION
)
AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 30 # in seconds
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION
)
AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
) # 3
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 3 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
) # 30
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 5 # in seconds
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
) # 12
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 5 # in seconds
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
) # 6
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK
)
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
) # 1
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 8 # in seconds
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
) # 4
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 3 # in seconds
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 5 # in seconds
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
)
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 8 # in seconds
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
)
GRAPH_VERSION_NAME: str = "a"

View File

@@ -169,6 +169,11 @@ POSTGRES_API_SERVER_POOL_SIZE = int(
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)
# defaults to False
# generally should only be used for
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
@@ -621,6 +626,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
# Set to true to mock LLM responses for testing purposes

View File

@@ -98,9 +98,18 @@ CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# needs to be long enough to cover the maximum time it takes to download an object
# hard timeout applied by the watchdog to the indexing connector run
# to handle hung connectors
CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds)
# soft timeout for the lock taken by the indexing connector run
# allows the lock to eventually expire if the managing code around it dies
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
# CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 15 minutes
# hard termination should always fire first if the connector is hung
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900
# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
@@ -165,6 +174,9 @@ class DocumentSource(str, Enum):
EGNYTE = "egnyte"
AIRTABLE = "airtable"
# Special case just for integration tests
MOCK_CONNECTOR = "mock_connector"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -243,6 +255,7 @@ class FileOrigin(str, Enum):
CHAT_IMAGE_GEN = "chat_image_gen"
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
INDEXING_CHECKPOINT = "indexing_checkpoint"
OTHER = "other"
@@ -274,6 +287,7 @@ class OnyxCeleryQueues:
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
CONNECTOR_DELETION = "connector_deletion"
LLM_MODEL_UPDATE = "llm_model_update"
CHECKPOINT_CLEANUP = "checkpoint_cleanup"
# Heavy queue
CONNECTOR_PRUNING = "connector_pruning"
@@ -293,6 +307,7 @@ class OnyxRedisLocks:
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat"
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
"da_lock:check_connector_doc_permissions_sync_beat"
)
@@ -368,6 +383,10 @@ class OnyxCeleryTask:
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
# Connector checkpoint cleanup
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
CLEANUP_CHECKPOINT = "cleanup_checkpoint"
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"

View File

@@ -245,7 +245,7 @@ class AirtableConnector(LoadConnector):
return [(" ".join(combined) if combined else str(field_info), default_link)]
if isinstance(field_info, list):
return [(item, default_link) for item in field_info]
return [(str(item), default_link) for item in field_info]
return [(str(field_info), default_link)]
@@ -268,7 +268,7 @@ class AirtableConnector(LoadConnector):
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[Section], dict[str, Any]]:
) -> tuple[list[Section], dict[str, str | list[str]]]:
"""
Process a single Airtable field and return sections or metadata.
@@ -342,7 +342,7 @@ class AirtableConnector(LoadConnector):
record_id = record["id"]
fields = record["fields"]
sections: list[Section] = []
metadata: dict[str, Any] = {}
metadata: dict[str, str | list[str]] = {}
# Get primary field value if it exists
primary_field_value = (

View File

@@ -7,15 +7,22 @@ from typing import Optional
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from botocore.exceptions import PartialCredentialsError
from mypy_boto3_s3 import S3Client # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import BlobType
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
@@ -240,6 +247,73 @@ class BlobStorageConnector(LoadConnector, PollConnector):
return None
def validate_connector_settings(self) -> None:
if self.s3_client is None:
raise ConnectorMissingCredentialError(
"Blob storage credentials not loaded."
)
if not self.bucket_name:
raise ConnectorValidationError(
"No bucket name was provided in connector settings."
)
try:
# We only fetch one object/page as a light-weight validation step.
# This ensures we trigger typical S3 permission checks (ListObjectsV2, etc.).
self.s3_client.list_objects_v2(
Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
)
except NoCredentialsError:
raise ConnectorMissingCredentialError(
"No valid blob storage credentials found or provided to boto3."
)
except PartialCredentialsError:
raise ConnectorMissingCredentialError(
"Partial or incomplete blob storage credentials provided to boto3."
)
except ClientError as e:
error_code = e.response["Error"].get("Code", "")
status_code = e.response["ResponseMetadata"].get("HTTPStatusCode")
# Most common S3 error cases
if error_code in [
"AccessDenied",
"InvalidAccessKeyId",
"SignatureDoesNotMatch",
]:
if status_code == 403 or error_code == "AccessDenied":
raise InsufficientPermissionsError(
f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
"Please check your bucket policy and/or IAM policy."
)
if status_code == 401 or error_code == "SignatureDoesNotMatch":
raise CredentialExpiredError(
"Provided blob storage credentials appear invalid or expired."
)
raise CredentialExpiredError(
f"Credential issue encountered ({error_code})."
)
if error_code == "NoSuchBucket" or status_code == 404:
raise ConnectorValidationError(
f"Bucket '{self.bucket_name}' does not exist or cannot be found."
)
raise ConnectorValidationError(
f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
)
except Exception as e:
# Catch-all for anything not captured by the above
# Since we are unsure of the error and it may not disable the connector,
# raise an unexpected error (does not disable connector)
raise UnexpectedError(
f"Unexpected error during blob storage settings validation: {e}"
)
if __name__ == "__main__":
credentials_dict = {

View File

@@ -5,6 +5,8 @@ import requests
class BookStackClientRequestFailedError(ConnectionError):
def __init__(self, status: int, error: str) -> None:
self.status_code = status
self.error = error
super().__init__(
"BookStack Client request failed with status {status}: {error}".format(
status=status, error=error

View File

@@ -7,8 +7,12 @@ from typing import Any
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.bookstack.client import BookStackApiClient
from onyx.connectors.bookstack.client import BookStackClientRequestFailedError
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -214,3 +218,39 @@ class BookstackConnector(LoadConnector, PollConnector):
break
else:
time.sleep(0.2)
def validate_connector_settings(self) -> None:
"""
Validate that the BookStack credentials and connector settings are correct.
Specifically checks that we can make an authenticated request to BookStack.
"""
if not self.bookstack_client:
raise ConnectorMissingCredentialError(
"BookStack credentials have not been loaded."
)
try:
# Attempt to fetch a small batch of books (arbitrary endpoint) to verify credentials
_ = self.bookstack_client.get(
"/books", params={"count": "1", "offset": "0"}
)
except BookStackClientRequestFailedError as e:
# Check for HTTP status codes
if e.status_code == 401:
raise CredentialExpiredError(
"Your BookStack credentials appear to be invalid or expired (HTTP 401)."
) from e
elif e.status_code == 403:
raise InsufficientPermissionsError(
"The configured BookStack token does not have sufficient permissions (HTTP 403)."
) from e
else:
raise ConnectorValidationError(
f"Unexpected BookStack error (status={e.status_code}): {e}"
) from e
except Exception as exc:
raise ConnectorValidationError(
f"Unexpected error while validating BookStack connector settings: {exc}"
) from exc

View File

@@ -8,6 +8,7 @@ from typing import TypeVar
from urllib.parse import quote
from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from requests import HTTPError
from onyx.utils.logger import setup_logger
@@ -29,6 +30,16 @@ class ConfluenceRateLimitError(Exception):
pass
class ConfluenceUser(BaseModel):
user_id: str # accountId in Cloud, userKey in Server
username: str | None # Confluence Cloud doesn't give usernames
display_name: str
# Confluence Data Center doesn't give email back by default,
# have to fetch it with a different endpoint
email: str | None
type: str
def _handle_http_error(e: HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
@@ -275,21 +286,95 @@ class OnyxConfluence(Confluence):
self,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
) -> Iterator[ConfluenceUser]:
"""
The search/user endpoint can be used to fetch users.
It's a seperate endpoint from the content/search endpoint used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
cql = "type=user"
url = "rest/api/search/user" if self.cloud else "rest/api/search"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
yield from self._paginate_url(url, limit)
if self.cloud:
cql = "type=user"
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
for user_result in self._paginate_url(url, limit):
# Example response:
# {
# 'user': {
# 'type': 'known',
# 'accountId': '712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
# 'accountType': 'atlassian',
# 'email': 'chris@danswer.ai',
# 'publicName': 'Chris Weaver',
# 'profilePicture': {
# 'path': '/wiki/aa-avatar/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
# 'width': 48,
# 'height': 48,
# 'isDefault': False
# },
# 'displayName': 'Chris Weaver',
# 'isExternalCollaborator': False,
# '_expandable': {
# 'operations': '',
# 'personalSpace': ''
# },
# '_links': {
# 'self': 'https://danswerai.atlassian.net/wiki/rest/api/user?accountId=712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d'
# }
# },
# 'title': 'Chris Weaver',
# 'excerpt': '',
# 'url': '/people/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
# 'breadcrumbs': [],
# 'entityType': 'user',
# 'iconCssClass': 'aui-icon content-type-profile',
# 'lastModified': '2025-02-18T04:08:03.579Z',
# 'score': 0.0
# }
user = user_result["user"]
yield ConfluenceUser(
user_id=user["accountId"],
username=None,
display_name=user["displayName"],
email=user.get("email"),
type=user["accountType"],
)
else:
# https://developer.atlassian.com/server/confluence/rest/v900/api-group-user/#api-rest-api-user-list-get
# ^ is only available on data center deployments
# Example response:
# [
# {
# 'type': 'known',
# 'username': 'admin',
# 'userKey': '40281082950c5fe901950c61c55d0000',
# 'profilePicture': {
# 'path': '/images/icons/profilepics/default.svg',
# 'width': 48,
# 'height': 48,
# 'isDefault': True
# },
# 'displayName': 'Admin Test',
# '_links': {
# 'self': 'http://localhost:8090/rest/api/user?key=40281082950c5fe901950c61c55d0000'
# },
# '_expandable': {
# 'status': ''
# }
# }
# ]
for user in self._paginate_url("rest/api/user/list", limit):
yield ConfluenceUser(
user_id=user["userKey"],
username=user["username"],
display_name=user["displayName"],
email=None,
type=user.get("type", "user"),
)
def paginated_groups_by_user_retrieval(
self,
user: dict[str, Any],
user_id: str, # accountId in Cloud, userKey in Server
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
@@ -297,7 +382,7 @@ class OnyxConfluence(Confluence):
It's a confluence specific endpoint that can be used to fetch groups.
"""
user_field = "accountId" if self.cloud else "key"
user_value = user["accountId"] if self.cloud else user["userKey"]
user_value = user_id
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
user_query = f"{user_field}={quote(user_value)}"

View File

@@ -1,11 +1,16 @@
import sys
import time
from collections.abc import Generator
from datetime import datetime
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.utils.logger import setup_logger
@@ -15,48 +20,139 @@ logger = setup_logger()
TimeRange = tuple[datetime, datetime]
class CheckpointOutputWrapper:
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
The connector format is easier for the connector implementor (e.g. it enforces exactly
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
formats.
"""
def __init__(self) -> None:
self.next_checkpoint: ConnectorCheckpoint | None = None
def __call__(
self,
checkpoint_connector_generator: CheckpointOutput,
) -> Generator[
tuple[Document | None, ConnectorFailure | None, ConnectorCheckpoint | None],
None,
None,
]:
# grabs the final return value and stores it in the `next_checkpoint` variable
def _inner_wrapper(
checkpoint_connector_generator: CheckpointOutput,
) -> CheckpointOutput:
self.next_checkpoint = yield from checkpoint_connector_generator
return self.next_checkpoint # not used
for document_or_failure in _inner_wrapper(checkpoint_connector_generator):
if isinstance(document_or_failure, Document):
yield document_or_failure, None, None
elif isinstance(document_or_failure, ConnectorFailure):
yield None, document_or_failure, None
else:
raise ValueError(
f"Invalid document_or_failure type: {type(document_or_failure)}"
)
if self.next_checkpoint is None:
raise RuntimeError(
"Checkpoint is None. This should never happen - the connector should always return a checkpoint."
)
yield None, None, self.next_checkpoint
class ConnectorRunner:
"""
Handles:
- Batching
- Additional exception logging
- Combining different connector types to a single interface
"""
def __init__(
self,
connector: BaseConnector,
batch_size: int,
time_range: TimeRange | None = None,
fail_loudly: bool = False,
):
self.connector = connector
self.time_range = time_range
self.batch_size = batch_size
if isinstance(self.connector, PollConnector):
if time_range is None:
raise ValueError("time_range is required for PollConnector")
self.doc_batch: list[Document] = []
self.doc_batch_generator = self.connector.poll_source(
time_range[0].timestamp(), time_range[1].timestamp()
)
elif isinstance(self.connector, LoadConnector):
if time_range and fail_loudly:
raise ValueError(
"time_range specified, but passed in connector is not a PollConnector"
)
self.doc_batch_generator = self.connector.load_from_state()
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
def run(self) -> GenerateDocumentsOutput:
def run(
self, checkpoint: ConnectorCheckpoint
) -> Generator[
tuple[
list[Document] | None, ConnectorFailure | None, ConnectorCheckpoint | None
],
None,
None,
]:
"""Adds additional exception logging to the connector."""
try:
start = time.monotonic()
for batch in self.doc_batch_generator:
# to know how long connector is taking
logger.debug(
f"Connector took {time.monotonic() - start} seconds to build a batch."
)
yield batch
if isinstance(self.connector, CheckpointConnector):
if self.time_range is None:
raise ValueError("time_range is required for CheckpointConnector")
start = time.monotonic()
checkpoint_connector_generator = self.connector.load_from_checkpoint(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
checkpoint=checkpoint,
)
next_checkpoint: ConnectorCheckpoint | None = None
# this is guaranteed to always run at least once with next_checkpoint being non-None
for document, failure, next_checkpoint in CheckpointOutputWrapper()(
checkpoint_connector_generator
):
if document is not None:
self.doc_batch.append(document)
if failure is not None:
yield None, failure, None
if len(self.doc_batch) >= self.batch_size:
yield self.doc_batch, None, None
self.doc_batch = []
# yield remaining documents
if len(self.doc_batch) > 0:
yield self.doc_batch, None, None
self.doc_batch = []
yield None, None, next_checkpoint
logger.debug(
f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint."
)
else:
finished_checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
finished_checkpoint.has_more = False
if isinstance(self.connector, PollConnector):
if self.time_range is None:
raise ValueError("time_range is required for PollConnector")
for document_batch in self.connector.poll_source(
start=self.time_range[0].timestamp(),
end=self.time_range[1].timestamp(),
):
yield document_batch, None, None
yield None, None, finished_checkpoint
elif isinstance(self.connector, LoadConnector):
for document_batch in self.connector.load_from_state():
yield document_batch, None, None
yield None, None, finished_checkpoint
else:
raise ValueError(f"Invalid connector. type: {type(self.connector)}")
except Exception:
exc_type, _, exc_traceback = sys.exc_info()
@@ -76,6 +172,6 @@ class ConnectorRunner:
)
logger.error(
f"Error in connector. type: {exc_type};\n"
f"local_vars below -> \n{local_vars_str}"
f"local_vars below -> \n{local_vars_str[:1024]}"
)
raise

View File

@@ -1,3 +1,4 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
@@ -24,16 +25,22 @@ def datetime_to_utc(dt: datetime) -> datetime:
def time_str_to_utc(datetime_str: str) -> datetime:
# Remove all timezone abbreviations in parentheses
datetime_str = re.sub(r"\([A-Z]+\)", "", datetime_str).strip()
# Remove any remaining parentheses and their contents
datetime_str = re.sub(r"\(.*?\)", "", datetime_str).strip()
try:
dt = parse(datetime_str)
except ValueError:
# Handle malformed timezone by attempting to fix common format issues
# Fix common format issues (e.g. "0000" => "+0000")
if "0000" in datetime_str:
# Convert "0000" to "+0000" for proper timezone parsing
fixed_dt_str = datetime_str.replace(" 0000", " +0000")
dt = parse(fixed_dt_str)
datetime_str = datetime_str.replace(" 0000", " +0000")
dt = parse(datetime_str)
else:
raise
return datetime_to_utc(dt)

View File

@@ -4,12 +4,16 @@ from typing import Any
from dropbox import Dropbox # type: ignore
from dropbox.exceptions import ApiError # type:ignore
from dropbox.exceptions import AuthError # type:ignore
from dropbox.files import FileMetadata # type:ignore
from dropbox.files import FolderMetadata # type:ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialInvalidError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -141,6 +145,29 @@ class DropboxConnector(LoadConnector, PollConnector):
return None
def validate_connector_settings(self) -> None:
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox credentials not loaded.")
try:
self.dropbox_client.files_list_folder(path="", limit=1)
except AuthError as e:
logger.exception("Failed to validate Dropbox credentials")
raise CredentialInvalidError(f"Dropbox credential is invalid: {e.error}")
except ApiError as e:
if (
e.error is not None
and "insufficient_permissions" in str(e.error).lower()
):
raise InsufficientPermissionsError(
"Your Dropbox token does not have sufficient permissions."
)
raise ConnectorValidationError(
f"Unexpected Dropbox error during validation: {e.user_message_text or e}"
)
except Exception as e:
raise Exception(f"Unexpected error during Dropbox settings validation: {e}")
if __name__ == "__main__":
import os

View File

@@ -30,12 +30,15 @@ from onyx.connectors.google_site.connector import GoogleSitesConnector
from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.linear.connector import LinearConnector
from onyx.connectors.loopio.connector import LoopioConnector
from onyx.connectors.mediawiki.wiki import MediaWikiConnector
from onyx.connectors.mock_connector.connector import MockConnector
from onyx.connectors.models import InputType
from onyx.connectors.notion.connector import NotionConnector
from onyx.connectors.onyx_jira.connector import JiraConnector
@@ -43,15 +46,18 @@ from onyx.connectors.productboard.connector import ProductboardConnector
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.sharepoint.connector import SharepointConnector
from onyx.connectors.slab.connector import SlabConnector
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.teams.connector import TeamsConnector
from onyx.connectors.web.connector import WebConnector
from onyx.connectors.wikipedia.connector import WikipediaConnector
from onyx.connectors.xenforo.connector import XenforoConnector
from onyx.connectors.zendesk.connector import ZendeskConnector
from onyx.connectors.zulip.connector import ZulipConnector
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.models import Credential
from onyx.db.models import User
class ConnectorMissingException(Exception):
@@ -66,8 +72,8 @@ def identify_connector_class(
DocumentSource.WEB: WebConnector,
DocumentSource.FILE: LocalFileConnector,
DocumentSource.SLACK: {
InputType.POLL: SlackPollConnector,
InputType.SLIM_RETRIEVAL: SlackPollConnector,
InputType.POLL: SlackConnector,
InputType.SLIM_RETRIEVAL: SlackConnector,
},
DocumentSource.GITHUB: GithubConnector,
DocumentSource.GMAIL: GmailConnector,
@@ -109,6 +115,8 @@ def identify_connector_class(
DocumentSource.FIREFLIES: FirefliesConnector,
DocumentSource.EGNYTE: EgnyteConnector,
DocumentSource.AIRTABLE: AirtableConnector,
# just for integration tests
DocumentSource.MOCK_CONNECTOR: MockConnector,
}
connector_by_source = connector_map.get(source, {})
@@ -125,10 +133,23 @@ def identify_connector_class(
if any(
[
input_type == InputType.LOAD_STATE
and not issubclass(connector, LoadConnector),
input_type == InputType.POLL and not issubclass(connector, PollConnector),
input_type == InputType.EVENT and not issubclass(connector, EventConnector),
(
input_type == InputType.LOAD_STATE
and not issubclass(connector, LoadConnector)
),
(
input_type == InputType.POLL
# either poll or checkpoint works for this, in the future
# all connectors should be checkpoint connectors
and (
not issubclass(connector, PollConnector)
and not issubclass(connector, CheckpointConnector)
)
),
(
input_type == InputType.EVENT
and not issubclass(connector, EventConnector)
),
]
):
raise ConnectorMissingException(
@@ -157,3 +178,43 @@ def instantiate_connector(
backend_update_credential_json(credential, new_credentials, db_session)
return connector
def validate_ccpair_for_user(
connector_id: int,
credential_id: int,
db_session: Session,
user: User | None,
tenant_id: str | None,
) -> None:
# Validate the connector settings
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id_for_user(
credential_id,
user,
db_session,
get_editable=False,
)
if not connector:
raise ValueError("Connector not found")
if connector.source == DocumentSource.INGESTION_API:
return
if not credential:
raise ValueError("Credential not found")
try:
runnable_connector = instantiate_connector(
db_session=db_session,
source=connector.source,
input_type=connector.input_type,
connector_specific_config=connector.connector_specific_config,
credential=credential,
tenant_id=tenant_id,
)
except Exception as e:
raise ConnectorValidationError(str(e))
runnable_connector.validate_connector_settings()

View File

@@ -181,7 +181,7 @@ class LocalFileConnector(LoadConnector):
documents: list[Document] = []
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
with get_session_with_tenant(self.tenant_id) as db_session:
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(

View File

@@ -187,12 +187,12 @@ class FirefliesConnector(PollConnector, LoadConnector):
return self._process_transcripts()
def poll_source(
self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(
start_unixtime, tz=timezone.utc
).strftime("%Y-%m-%dT%H:%M:%S.000Z")
end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)

View File

@@ -9,6 +9,7 @@ from typing import cast
from github import Github
from github import RateLimitExceededException
from github import Repository
from github.GithubException import GithubException
from github.Issue import Issue
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
@@ -16,17 +17,20 @@ from github.PullRequest import PullRequest
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -226,6 +230,48 @@ class GithubConnector(LoadConnector, PollConnector):
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
def validate_connector_settings(self) -> None:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
if not self.repo_owner or not self.repo_name:
raise ConnectorValidationError(
"Invalid connector settings: 'repo_owner' and 'repo_name' must be provided."
)
try:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
except RateLimitExceededException:
raise UnexpectedError(
"Validation failed due to GitHub rate-limits being exceeded. Please try again later."
)
except GithubException as e:
if e.status == 401:
raise CredentialExpiredError(
"GitHub credential appears to be invalid or expired (HTTP 401)."
)
elif e.status == 403:
raise InsufficientPermissionsError(
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"
)
except Exception as exc:
raise Exception(
f"Unexpected error during GitHub settings validation: {exc}"
)
if __name__ == "__main__":
import os

View File

@@ -297,6 +297,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
full_threads = execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().get,

View File

@@ -87,16 +87,18 @@ class HubSpotConnector(LoadConnector, PollConnector):
contact = api_client.crm.contacts.basic_api.get_by_id(
contact_id=contact.id
)
associated_emails.append(contact.properties["email"])
email = contact.properties.get("email")
if email is not None:
associated_emails.append(email)
if notes:
for note in notes.results:
note = api_client.crm.objects.notes.basic_api.get_by_id(
note_id=note.id, properties=["content", "hs_body_preview"]
)
if note.properties["hs_body_preview"] is None:
continue
associated_notes.append(note.properties["hs_body_preview"])
preview = note.properties.get("hs_body_preview")
if preview is not None:
associated_notes.append(preview)
associated_emails_str = " ,".join(associated_emails)
associated_notes_str = " ".join(associated_notes)

View File

@@ -1,19 +1,22 @@
import abc
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Iterator[list[Document]]
GenerateSlimDocumentOutput = Iterator[list[SlimDocument]]
CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpoint]
class BaseConnector(abc.ABC):
@@ -41,6 +44,14 @@ class BaseConnector(abc.ABC):
raise RuntimeError(custom_parser_req_msg)
return metadata_lines
def validate_connector_settings(self) -> None:
"""
Override this if your connector needs to validate credentials or settings.
Raise an exception if invalid, otherwise do nothing.
Default is a no-op (always successful).
"""
# Large set update or reindex, generally pulling a complete state or from a savestate file
class LoadConnector(BaseConnector):
@@ -105,3 +116,76 @@ class EventConnector(BaseConnector):
@abc.abstractmethod
def handle_event(self, event: Any) -> GenerateDocumentsOutput:
raise NotImplementedError
class CheckpointConnector(BaseConnector):
@abc.abstractmethod
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
"""Yields back documents or failures. Final return is the new checkpoint.
Final return can be access via either:
```
try:
for document_or_failure in connector.load_from_checkpoint(start, end, checkpoint):
print(document_or_failure)
except StopIteration as e:
checkpoint = e.value # Extracting the return value
print(checkpoint)
```
OR
```
checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint)
```
"""
raise NotImplementedError
class ConnectorValidationError(Exception):
"""General exception for connector validation errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class UnexpectedError(Exception):
"""Raised when an unexpected error occurs during connector validation.
Unexpected errors don't necessarily mean the credential is invalid,
but rather that there was an error during the validation process
or we encountered a currently unhandled error case.
"""
def __init__(self, message: str = "Unexpected error during connector validation"):
super().__init__(message)
class CredentialInvalidError(ConnectorValidationError):
"""Raised when a connector's credential is invalid."""
def __init__(self, message: str = "Credential is invalid"):
super().__init__(message)
class CredentialExpiredError(ConnectorValidationError):
"""Raised when a connector's credential is expired."""
def __init__(self, message: str = "Credential has expired"):
super().__init__(message)
class InsufficientPermissionsError(ConnectorValidationError):
"""Raised when the credential does not have sufficient API permissions."""
def __init__(
self, message: str = "Insufficient permissions for the requested operation"
):
super().__init__(message)

View File

@@ -0,0 +1,86 @@
from typing import Any
import httpx
from pydantic import BaseModel
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.utils.logger import setup_logger
logger = setup_logger()
class SingleConnectorYield(BaseModel):
documents: list[Document]
checkpoint: ConnectorCheckpoint
failures: list[ConnectorFailure]
unhandled_exception: str | None = None
class MockConnector(CheckpointConnector):
def __init__(
self,
mock_server_host: str,
mock_server_port: int,
) -> None:
self.mock_server_host = mock_server_host
self.mock_server_port = mock_server_port
self.client = httpx.Client(timeout=30.0)
self.connector_yields: list[SingleConnectorYield] | None = None
self.current_yield_index: int = 0
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
response = self.client.get(self._get_mock_server_url("get-documents"))
response.raise_for_status()
data = response.json()
self.connector_yields = [
SingleConnectorYield(**yield_data) for yield_data in data
]
return None
def _get_mock_server_url(self, endpoint: str) -> str:
return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"
def _save_checkpoint(self, checkpoint: ConnectorCheckpoint) -> None:
response = self.client.post(
self._get_mock_server_url("add-checkpoint"),
json=checkpoint.model_dump(mode="json"),
)
response.raise_for_status()
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
if self.connector_yields is None:
raise ValueError("No connector yields configured")
# Save the checkpoint to the mock server
self._save_checkpoint(checkpoint)
yield_index = self.current_yield_index
self.current_yield_index += 1
current_yield = self.connector_yields[yield_index]
# If the current yield has an unhandled exception, raise it
# This is used to simulate an unhandled failure in the connector.
if current_yield.unhandled_exception:
raise RuntimeError(current_yield.unhandled_exception)
# yield all documents
for document in current_yield.documents:
yield document
for failure in current_yield.failures:
yield failure
return current_yield.checkpoint

View File

@@ -3,6 +3,7 @@ from enum import Enum
from typing import Any
from pydantic import BaseModel
from pydantic import model_validator
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
@@ -187,36 +188,48 @@ class SlimDocument(BaseModel):
perm_sync_data: Any | None = None
class DocumentErrorSummary(BaseModel):
id: str
semantic_id: str
section_link: str | None
@classmethod
def from_document(cls, doc: Document) -> "DocumentErrorSummary":
section_link = doc.sections[0].link if len(doc.sections) > 0 else None
return cls(
id=doc.id, semantic_id=doc.semantic_identifier, section_link=section_link
)
@classmethod
def from_dict(cls, data: dict) -> "DocumentErrorSummary":
return cls(
id=str(data.get("id")),
semantic_id=str(data.get("semantic_id")),
section_link=str(data.get("section_link")),
)
def to_dict(self) -> dict[str, str | None]:
return {
"id": self.id,
"semantic_id": self.semantic_id,
"section_link": self.section_link,
}
class IndexAttemptMetadata(BaseModel):
batch_num: int | None = None
num_exceptions: int = 0
connector_id: int
credential_id: int
class ConnectorCheckpoint(BaseModel):
# TODO: maybe move this to something disk-based to handle extremely large checkpoints?
checkpoint_content: dict
has_more: bool
@classmethod
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
class DocumentFailure(BaseModel):
document_id: str
document_link: str | None = None
class EntityFailure(BaseModel):
entity_id: str
missed_time_range: tuple[datetime, datetime] | None = None
class ConnectorFailure(BaseModel):
failed_document: DocumentFailure | None = None
failed_entity: EntityFailure | None = None
failure_message: str
exception: Exception | None = None
model_config = {"arbitrary_types_allowed": True}
@model_validator(mode="before")
def check_failed_fields(cls, values: dict) -> dict:
failed_document = values.get("failed_document")
failed_entity = values.get("failed_entity")
if (failed_document is None and failed_entity is None) or (
failed_document is not None and failed_entity is not None
):
raise ValueError(
"Exactly one of 'failed_document' or 'failed_entity' must be specified."
)
return values

View File

@@ -7,6 +7,7 @@ from datetime import timezone
from typing import Any
from typing import Optional
import requests
from retry import retry
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@@ -15,10 +16,14 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rl_requests,
)
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
@@ -616,6 +621,64 @@ class NotionConnector(LoadConnector, PollConnector):
else:
break
def validate_connector_settings(self) -> None:
if not self.headers.get("Authorization"):
raise ConnectorMissingCredentialError("Notion credentials not loaded.")
try:
# We'll do a minimal search call (page_size=1) to confirm accessibility
if self.root_page_id:
# If root_page_id is set, fetch the specific page
res = rl_requests.get(
f"https://api.notion.com/v1/pages/{self.root_page_id}",
headers=self.headers,
timeout=_NOTION_CALL_TIMEOUT,
)
else:
# If root_page_id is not set, perform a minimal search
test_query = {
"filter": {"property": "object", "value": "page"},
"page_size": 1,
}
res = rl_requests.post(
"https://api.notion.com/v1/search",
headers=self.headers,
json=test_query,
timeout=_NOTION_CALL_TIMEOUT,
)
res.raise_for_status()
except requests.exceptions.HTTPError as http_err:
status_code = http_err.response.status_code if http_err.response else None
if status_code == 401:
raise CredentialExpiredError(
"Notion credential appears to be invalid or expired (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Notion token does not have sufficient permissions (HTTP 403)."
)
elif status_code == 404:
# Typically means resource not found or not shared. Could be root_page_id is invalid.
raise ConnectorValidationError(
"Notion resource not found or not shared with the integration (HTTP 404)."
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Notion rate-limits being exceeded (HTTP 429). "
"Please try again later."
)
else:
raise Exception(
f"Unexpected Notion HTTP error (status={status_code}): {http_err}"
) from http_err
except Exception as exc:
raise Exception(
f"Unexpected error during Notion settings validation: {exc}"
)
if __name__ == "__main__":
import os

View File

@@ -12,8 +12,11 @@ from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -272,6 +275,40 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
yield slim_doc_batch
def validate_connector_settings(self) -> None:
if self._jira_client is None:
raise ConnectorMissingCredentialError("Jira")
if not self._jira_project:
raise ConnectorValidationError(
"Invalid connector settings: 'jira_project' must be provided."
)
try:
self.jira_client.project(self._jira_project)
except Exception as e:
status_code = getattr(e, "status_code", None)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
)
elif status_code == 404:
raise ConnectorValidationError(
f"Jira project not found with key: {self._jira_project}"
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
else:
raise Exception(f"Unexpected Jira error during validation: {e}")
if __name__ == "__main__":
import os

View File

@@ -1,10 +1,16 @@
import contextvars
import copy
import re
from collections.abc import Callable
from collections.abc import Generator
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypedDict
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
@@ -12,14 +18,22 @@ from slack_sdk.errors import SlackApiError
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.slack.utils import expert_info_from_slack_id
@@ -33,6 +47,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_SLACK_LIMIT = 900
ChannelType = dict[str, Any]
MessageType = dict[str, Any]
@@ -40,6 +56,13 @@ MessageType = dict[str, Any]
ThreadType = list[MessageType]
class SlackCheckpointContent(TypedDict):
channel_ids: list[str]
channel_completion_map: dict[str, str]
current_channel: ChannelType | None
seen_thread_ts: list[str]
def _collect_paginated_channels(
client: WebClient,
exclude_archived: bool,
@@ -140,6 +163,10 @@ def get_latest_message_time(thread: ThreadType) -> datetime:
return datetime.fromtimestamp(max_ts, tz=timezone.utc)
def _build_doc_id(channel_id: str, thread_ts: str) -> str:
return f"{channel_id}__{thread_ts}"
def thread_to_doc(
channel: ChannelType,
thread: ThreadType,
@@ -182,7 +209,7 @@ def thread_to_doc(
)
return Document(
id=f"{channel_id}__{thread[0]['ts']}",
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
sections=[
Section(
link=get_message_link(event=m, client=client, channel_id=channel_id),
@@ -267,64 +294,97 @@ def filter_channels(
]
def _get_all_docs(
def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
"""Get a channel by its ID.
Args:
client: The Slack WebClient instance
channel_id: The ID of the channel to fetch
Returns:
The channel information
Raises:
SlackApiError: If the channel cannot be fetched
"""
response = make_slack_api_call_w_retries(
client.conversations_info,
channel=channel_id,
)
return cast(ChannelType, response["channel"])
def _get_messages(
channel: ChannelType,
client: WebClient,
channels: list[str] | None = None,
channel_name_regex_enabled: bool = False,
oldest: str | None = None,
latest: str | None = None,
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Generator[Document, None, None]:
"""Get all documents in the workspace, channel by channel"""
slack_cleaner = SlackTextCleaner(client=client)
) -> tuple[list[MessageType], bool]:
"""Slack goes from newest to oldest."""
# Cache to prevent refetching via API since users
user_cache: dict[str, BasicExpertInfo | None] = {}
# have to be in the channel in order to read messages
if not channel["is_member"]:
make_slack_api_call_w_retries(
client.conversations_join,
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
all_channels = get_channels(client)
filtered_channels = filter_channels(
all_channels, channels, channel_name_regex_enabled
response = make_slack_api_call_w_retries(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
latest=latest,
limit=_SLACK_LIMIT,
)
response.validate()
for channel in filtered_channels:
channel_docs = 0
channel_message_batches = get_channel_messages(
client=client, channel=channel, oldest=oldest, latest=latest
messages = cast(list[MessageType], response.get("messages", []))
cursor = cast(dict[str, Any], response.get("response_metadata", {})).get(
"next_cursor", ""
)
has_more = bool(cursor)
return messages, has_more
def _message_to_doc(
message: MessageType,
client: WebClient,
channel: ChannelType,
slack_cleaner: SlackTextCleaner,
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> Document | None:
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
return None
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
elif not msg_filter_func(message):
filtered_thread = [message]
if filtered_thread:
return thread_to_doc(
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
client=client,
user_cache=user_cache,
)
seen_thread_ts: set[str] = set()
for message_batch in channel_message_batches:
for message in message_batch:
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
continue
seen_thread_ts.add(thread_ts)
thread = get_thread(
client=client, channel_id=channel["id"], thread_id=thread_ts
)
filtered_thread = [
message for message in thread if not msg_filter_func(message)
]
elif not msg_filter_func(message):
filtered_thread = [message]
if filtered_thread:
channel_docs += 1
yield thread_to_doc(
channel=channel,
thread=filtered_thread,
slack_cleaner=slack_cleaner,
client=client,
user_cache=user_cache,
)
logger.info(
f"Pulled {channel_docs} documents from slack channel {channel['name']}"
)
return None
def _get_all_doc_ids(
@@ -368,7 +428,7 @@ def _get_all_doc_ids(
for message_ts in message_ts_set:
channel_metadata_list.append(
SlimDocument(
id=f"{channel_id}__{message_ts}",
id=_build_doc_id(channel_id=channel_id, thread_ts=message_ts),
perm_sync_data={"channel_id": channel_id},
)
)
@@ -376,7 +436,51 @@ def _get_all_doc_ids(
yield channel_metadata_list
class SlackPollConnector(PollConnector, SlimConnector):
def _process_message(
message: MessageType,
client: WebClient,
channel: ChannelType,
slack_cleaner: SlackTextCleaner,
user_cache: dict[str, BasicExpertInfo | None],
seen_thread_ts: set[str],
msg_filter_func: Callable[[MessageType], bool] = default_msg_filter,
) -> tuple[Document | None, str | None, ConnectorFailure | None]:
thread_ts = message.get("thread_ts")
try:
# causes random failures for testing checkpointing / continue on failure
# import random
# if random.random() > 0.95:
# raise RuntimeError("Random failure :P")
doc = _message_to_doc(
message=message,
client=client,
channel=channel,
slack_cleaner=slack_cleaner,
user_cache=user_cache,
seen_thread_ts=seen_thread_ts,
msg_filter_func=msg_filter_func,
)
return (doc, thread_ts, None)
except Exception as e:
logger.exception(f"Error processing message {message['ts']}")
return (
None,
thread_ts,
ConnectorFailure(
failed_document=DocumentFailure(
document_id=_build_doc_id(
channel_id=channel["id"], thread_ts=(thread_ts or message["ts"])
),
document_link=get_message_link(message, client, channel["id"]),
),
failure_message=str(e),
exception=e,
),
)
class SlackConnector(SlimConnector, CheckpointConnector):
def __init__(
self,
channels: list[str] | None = None,
@@ -390,9 +494,14 @@ class SlackPollConnector(PollConnector, SlimConnector):
self.batch_size = batch_size
self.client: WebClient | None = None
# just used for efficiency
self.text_cleaner: SlackTextCleaner | None = None
self.user_cache: dict[str, BasicExpertInfo | None] = {}
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
bot_token = credentials["slack_bot_token"]
self.client = WebClient(token=bot_token)
self.text_cleaner = SlackTextCleaner(client=self.client)
return None
def retrieve_all_slim_documents(
@@ -411,30 +520,205 @@ class SlackPollConnector(PollConnector, SlimConnector):
callback=callback,
)
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
if self.client is None:
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: ConnectorCheckpoint,
) -> CheckpointOutput:
"""Rough outline:
Step 1: Get all channels, yield back Checkpoint.
Step 2: Loop through each channel. For each channel:
Step 2.1: Get messages within the time range.
Step 2.2: Process messages in parallel, yield back docs.
Step 2.3: Update checkpoint with new_latest, seen_thread_ts, and current_channel.
Slack returns messages from newest to oldest, so we need to keep track of
the latest message we've seen in each channel.
Step 2.4: If there are no more messages in the channel, switch the current
channel to the next channel.
"""
if self.client is None or self.text_cleaner is None:
raise ConnectorMissingCredentialError("Slack")
documents: list[Document] = []
for document in _get_all_docs(
client=self.client,
channels=self.channels,
channel_name_regex_enabled=self.channel_regex_enabled,
# NOTE: need to impute to `None` instead of using 0.0, since Slack will
# throw an error if we use 0.0 on an account without infinite data
# retention
oldest=str(start) if start else None,
latest=str(end),
):
documents.append(document)
if len(documents) >= self.batch_size:
yield documents
documents = []
checkpoint_content = cast(
SlackCheckpointContent,
(
copy.deepcopy(checkpoint.checkpoint_content)
or {
"channel_ids": None,
"channel_completion_map": {},
"current_channel": None,
"seen_thread_ts": [],
}
),
)
if documents:
yield documents
# if this is the very first time we've called this, need to
# get all relevant channels and save them into the checkpoint
if checkpoint_content["channel_ids"] is None:
raw_channels = get_channels(self.client)
filtered_channels = filter_channels(
raw_channels, self.channels, self.channel_regex_enabled
)
if len(filtered_channels) == 0:
return checkpoint
checkpoint_content["channel_ids"] = [c["id"] for c in filtered_channels]
checkpoint_content["current_channel"] = filtered_channels[0]
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=True,
)
return checkpoint
final_channel_ids = checkpoint_content["channel_ids"]
channel = checkpoint_content["current_channel"]
if channel is None:
raise ValueError("current_channel key not found in checkpoint")
channel_id = channel["id"]
if channel_id not in final_channel_ids:
raise ValueError(f"Channel {channel_id} not found in checkpoint")
oldest = str(start) if start else None
latest = checkpoint_content["channel_completion_map"].get(channel_id, str(end))
seen_thread_ts = set(checkpoint_content["seen_thread_ts"])
try:
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
)
message_batch, has_more_in_channel = _get_messages(
channel, self.client, oldest, latest
)
new_latest = message_batch[-1]["ts"] if message_batch else latest
# Process messages in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=8) as executor:
futures: list[Future] = []
for message in message_batch:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
futures.append(
executor.submit(
current_context.run,
_process_message,
message=message,
client=self.client,
channel=channel,
slack_cleaner=self.text_cleaner,
user_cache=self.user_cache,
seen_thread_ts=seen_thread_ts,
)
)
for future in as_completed(futures):
doc, thread_ts, failures = future.result()
if doc:
# handle race conditions here since this is single
# threaded. Multi-threaded _process_message reads from this
# but since this is single threaded, we won't run into simul
# writes. At worst, we can duplicate a thread, which will be
# deduped later on.
if thread_ts not in seen_thread_ts:
yield doc
if thread_ts:
seen_thread_ts.add(thread_ts)
elif failures:
for failure in failures:
yield failure
checkpoint_content["seen_thread_ts"] = list(seen_thread_ts)
checkpoint_content["channel_completion_map"][channel["id"]] = new_latest
if has_more_in_channel:
checkpoint_content["current_channel"] = channel
else:
new_channel_id = next(
(
channel_id
for channel_id in final_channel_ids
if channel_id
not in checkpoint_content["channel_completion_map"]
),
None,
)
if new_channel_id:
new_channel = _get_channel_by_id(self.client, new_channel_id)
checkpoint_content["current_channel"] = new_channel
else:
checkpoint_content["current_channel"] = None
checkpoint = ConnectorCheckpoint(
checkpoint_content=checkpoint_content, # type: ignore
has_more=checkpoint_content["current_channel"] is not None,
)
return checkpoint
except Exception as e:
logger.exception(f"Error processing channel {channel['name']}")
yield ConnectorFailure(
failed_entity=EntityFailure(
entity_id=channel["id"],
missed_time_range=(
datetime.fromtimestamp(start, tz=timezone.utc),
datetime.fromtimestamp(end, tz=timezone.utc),
),
),
failure_message=str(e),
exception=e,
)
return checkpoint
def validate_connector_settings(self) -> None:
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
try:
# Minimal API call to confirm we can list channels
# We set limit=1 for a lightweight check
response = self.client.conversations_list(limit=1, types=["public_channel"])
# Just ensure Slack responded "ok: True"
if not response.get("ok", False):
error_msg = response.get("error", "Unknown error from Slack")
if error_msg == "invalid_auth":
raise ConnectorValidationError(
f"Invalid or expired Slack bot token ({error_msg})."
)
elif error_msg == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({error_msg})."
)
raise UnexpectedError(f"Slack API returned a failure: {error_msg}")
except SlackApiError as e:
slack_error = e.response.get("error", "")
if slack_error == "missing_scope":
# The needed scope is typically "channels:read" or "groups:read"
# for viewing channels. The error message might also contain the
# specific scope needed vs. provided.
raise InsufficientPermissionsError(
"Slack bot token lacks the necessary scope to list channels. "
"Please ensure your Slack app has 'channels:read' (or 'groups:read' for private channels) enabled."
)
elif slack_error == "invalid_auth":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
)
elif slack_error == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
)
else:
# Generic Slack error
raise UnexpectedError(
f"Unexpected Slack error '{slack_error}' during settings validation."
)
except Exception as e:
# Catch-all for unexpected exceptions
raise UnexpectedError(
f"Unexpected error during Slack settings validation: {e}"
)
if __name__ == "__main__":
@@ -442,7 +726,7 @@ if __name__ == "__main__":
import time
slack_channel = os.environ.get("SLACK_CHANNEL")
connector = SlackPollConnector(
connector = SlackConnector(
channels=[slack_channel] if slack_channel else None,
)
connector.load_credentials({"slack_bot_token": os.environ["SLACK_BOT_TOKEN"]})
@@ -450,6 +734,17 @@ if __name__ == "__main__":
current = time.time()
one_day_ago = current - 24 * 60 * 60 # 1 day
document_batches = connector.poll_source(one_day_ago, current)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
print(next(document_batches))
gen = connector.load_from_checkpoint(one_day_ago, current, checkpoint)
try:
for document_or_failure in gen:
if isinstance(document_or_failure, Document):
print(document_or_failure)
elif isinstance(document_or_failure, ConnectorFailure):
print(document_or_failure)
except StopIteration as e:
checkpoint = e.value
print("Next checkpoint:", checkpoint)
print("Next checkpoint:", checkpoint)

View File

@@ -34,9 +34,14 @@ def get_message_link(
) -> str:
channel_id = channel_id or event["channel"]
message_ts = event["ts"]
response = client.chat_getPermalink(channel=channel_id, message_ts=message_ts)
permalink = response["permalink"]
return permalink
message_ts_without_dot = message_ts.replace(".", "")
thread_ts = event.get("thread_ts")
base_url = get_base_url(client.token)
link = f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + (
f"?thread_ts={thread_ts}" if thread_ts else ""
)
return link
def _make_slack_api_call_paginated(

View File

@@ -5,6 +5,7 @@ from typing import Any
import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.runtime.client_request_exception import ClientRequestException # type: ignore
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.chats.messages.message import ChatMessage # type: ignore
from office365.teams.team import Team # type: ignore
@@ -12,10 +13,14 @@ from office365.teams.team import Team # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
@@ -279,6 +284,64 @@ class TeamsConnector(LoadConnector, PollConnector):
end_datetime = datetime.fromtimestamp(end, timezone.utc)
return self._fetch_from_teams(start=start_datetime, end=end_datetime)
def validate_connector_settings(self) -> None:
"""
Validate that we can connect to Microsoft Teams with the provided MSAL/Graph credentials
and that we can see at least one Team. If the user has specified a list of Teams by name,
confirm at least one of them is found.
Raises:
ConnectorMissingCredentialError: If the Graph client is not yet set (missing credentials).
CredentialExpiredError: If credentials appear invalid/expired (e.g. 401 Unauthorized).
InsufficientPermissionsError: If the app lacks required permissions to read Teams.
ConnectorValidationError: If no Teams are found, or if requested Teams are not found.
"""
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams credentials not loaded.")
try:
# Minimal call to confirm we can retrieve Teams
found_teams = self._get_all_teams()
# You may optionally catch the Graph/Office365 request exception if available:
except ClientRequestException as e:
status_code = e.response.status_code
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
)
else:
raise UnexpectedError(f"Unexpected error retrieving teams: {e}")
except Exception as e:
error_str = str(e).lower()
if (
"unauthorized" in error_str
or "401" in error_str
or "invalid_grant" in error_str
):
raise CredentialExpiredError(
"Invalid or expired Microsoft Teams credentials."
)
elif "forbidden" in error_str or "403" in error_str:
raise InsufficientPermissionsError(
"App lacks required permissions to read from Microsoft Teams."
)
raise ConnectorValidationError(
f"Unexpected error during Teams validation: {e}"
)
# If we get this far, the Graph call succeeded. Check for presence of Teams:
if not found_teams:
raise ConnectorValidationError(
"No Teams found for the given credentials. "
"Either there are no Teams in this tenant, or your app does not have permission to view them."
)
if __name__ == "__main__":
connector = TeamsConnector(teams=os.environ["TEAMS"].split(","))

View File

@@ -25,8 +25,12 @@ from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_TOKEN_URL
from onyx.configs.app_configs import WEB_CONNECTOR_VALIDATE_URLS
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import read_pdf_file
@@ -37,6 +41,8 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
# Given a base site, index everything under that path
@@ -170,26 +176,35 @@ def start_playwright() -> Tuple[Playwright, BrowserContext]:
def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
response = requests.get(sitemap_url)
response.raise_for_status()
try:
response = requests.get(sitemap_url)
response.raise_for_status()
soup = BeautifulSoup(response.content, "html.parser")
urls = [
_ensure_absolute_url(sitemap_url, loc_tag.text)
for loc_tag in soup.find_all("loc")
]
soup = BeautifulSoup(response.content, "html.parser")
urls = [
_ensure_absolute_url(sitemap_url, loc_tag.text)
for loc_tag in soup.find_all("loc")
]
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
# the given url doesn't look like a sitemap, let's try to find one
urls = list_pages_for_site(sitemap_url)
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
# the given url doesn't look like a sitemap, let's try to find one
urls = list_pages_for_site(sitemap_url)
if len(urls) == 0:
raise ValueError(
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
if len(urls) == 0:
raise ValueError(
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
)
return urls
except requests.RequestException as e:
raise RuntimeError(f"Failed to fetch sitemap from {sitemap_url}: {e}")
except ValueError as e:
raise RuntimeError(f"Error processing sitemap {sitemap_url}: {e}")
except Exception as e:
raise RuntimeError(
f"Unexpected error while processing sitemap {sitemap_url}: {e}"
)
return urls
def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:
if not urlparse(maybe_relative_url).netloc:
@@ -225,10 +240,14 @@ class WebConnector(LoadConnector):
web_connector_type: str = WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value,
mintlify_cleanup: bool = True, # Mostly ok to apply to other websites as well
batch_size: int = INDEX_BATCH_SIZE,
scroll_before_scraping: bool = False,
**kwargs: Any,
) -> None:
self.mintlify_cleanup = mintlify_cleanup
self.batch_size = batch_size
self.recursive = False
self.scroll_before_scraping = scroll_before_scraping
self.web_connector_type = web_connector_type
if web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value:
self.recursive = True
@@ -344,6 +363,18 @@ class WebConnector(LoadConnector):
continue
visited_links.add(current_url)
if self.scroll_before_scraping:
scroll_attempts = 0
previous_height = page.evaluate("document.body.scrollHeight")
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
page.wait_for_load_state("networkidle", timeout=30000)
new_height = page.evaluate("document.body.scrollHeight")
if new_height == previous_height:
break # Stop scrolling when no more content is loaded
previous_height = new_height
scroll_attempts += 1
content = page.content()
soup = BeautifulSoup(content, "html.parser")
@@ -402,6 +433,56 @@ class WebConnector(LoadConnector):
raise RuntimeError(last_error)
raise RuntimeError("No valid pages found.")
def validate_connector_settings(self) -> None:
# Make sure we have at least one valid URL to check
if not self.to_visit_list:
raise ConnectorValidationError(
"No URL configured. Please provide at least one valid URL."
)
if (
self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SITEMAP.value
or self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value
):
return None
# We'll just test the first URL for connectivity and correctness
test_url = self.to_visit_list[0]
# Check that the URL is allowed and well-formed
try:
protected_url_check(test_url)
except ValueError as e:
raise ConnectorValidationError(
f"Protected URL check failed for '{test_url}': {e}"
)
except ConnectionError as e:
# Typically DNS or other network issues
raise ConnectorValidationError(str(e))
# Make a quick request to see if we get a valid response
try:
check_internet_connection(test_url)
except Exception as e:
err_str = str(e)
if "401" in err_str:
raise CredentialExpiredError(
f"Unauthorized access to '{test_url}': {e}"
)
elif "403" in err_str:
raise InsufficientPermissionsError(
f"Forbidden access to '{test_url}': {e}"
)
elif "404" in err_str:
raise ConnectorValidationError(f"Page not found for '{test_url}': {e}")
elif "Max retries exceeded" in err_str and "NameResolutionError" in err_str:
raise ConnectorValidationError(
f"Unable to resolve hostname for '{test_url}'. Please check the URL and your internet connection."
)
else:
# Could be a 5xx or another error, treat as unexpected
raise UnexpectedError(f"Unexpected error validating '{test_url}': {e}")
if __name__ == "__main__":
connector = WebConnector("https://docs.onyx.app/")

View File

@@ -1,9 +1,14 @@
import os
import tempfile
import urllib.parse
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
from zulip import Client
@@ -36,8 +41,39 @@ class ZulipConnector(LoadConnector, PollConnector):
) -> None:
self.batch_size = batch_size
self.realm_name = realm_name
self.realm_url = realm_url if realm_url.endswith("/") else realm_url + "/"
self.client: Client | None = None
# Clean and normalize the URL
realm_url = realm_url.strip().lower()
# Remove any trailing slashes
realm_url = realm_url.rstrip("/")
# Ensure the URL has a scheme
if not realm_url.startswith(("http://", "https://")):
realm_url = f"https://{realm_url}"
try:
parsed = urllib.parse.urlparse(realm_url)
# Extract the base domain without any paths or ports
netloc = parsed.netloc.split(":")[0] # Remove port if present
if not netloc:
raise ValueError(
f"Invalid realm URL format: {realm_url}. "
f"URL must include a valid domain name."
)
# Always use HTTPS for security
self.base_url = f"https://{netloc}"
self.client: Client | None = None
except Exception as e:
raise ValueError(
f"Failed to parse Zulip realm URL: {realm_url}. "
f"Please provide a URL in the format: domain.com or https://domain.com. "
f"Error: {str(e)}"
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
contents = credentials["zuliprc_content"]
@@ -55,12 +91,17 @@ class ZulipConnector(LoadConnector, PollConnector):
return None
def _message_to_narrow_link(self, m: Message) -> str:
stream_name = m.display_recipient # assume str
stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}")
topic_operand = encode_zulip_narrow_operand(m.subject)
try:
stream_name = m.display_recipient # assume str
stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}")
topic_operand = encode_zulip_narrow_operand(m.subject)
narrow_link = f"{self.realm_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}"
return narrow_link
narrow_link = f"{self.base_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}"
return narrow_link
except Exception as e:
logger.error(f"Error generating Zulip message link: {e}")
# Fallback to a basic link that at least includes the base URL
return f"{self.base_url}#narrow/id/{m.id}"
def _get_message_batch(self, anchor: str) -> Tuple[bool, List[Message]]:
if self.client is None:
@@ -83,6 +124,40 @@ class ZulipConnector(LoadConnector, PollConnector):
def _message_to_doc(self, message: Message) -> Document:
text = f"{message.sender_full_name}: {message.content}"
try:
# Convert timestamps to UTC datetime objects
post_time = datetime.fromtimestamp(message.timestamp, tz=timezone.utc)
edit_time = (
datetime.fromtimestamp(message.last_edit_timestamp, tz=timezone.utc)
if message.last_edit_timestamp is not None
else None
)
# Use the most recent edit time if available, otherwise use post time
doc_time = edit_time if edit_time is not None else post_time
except (ValueError, TypeError) as e:
logger.warning(f"Failed to parse timestamp for message {message.id}: {e}")
post_time = None
edit_time = None
doc_time = None
metadata: Dict[str, Union[str, List[str]]] = {
"stream_name": str(message.display_recipient),
"topic": str(message.subject),
"sender_name": str(message.sender_full_name),
"sender_email": str(message.sender_email),
"message_timestamp": str(message.timestamp),
"message_id": str(message.id),
"stream_id": str(message.stream_id),
"has_reactions": str(len(message.reactions) > 0),
"content_type": str(message.content_type or "text"),
}
# Always include edit timestamp in metadata when available
if edit_time is not None:
metadata["edit_timestamp"] = str(message.last_edit_timestamp)
return Document(
id=f"{message.stream_id}__{message.id}",
sections=[
@@ -92,8 +167,9 @@ class ZulipConnector(LoadConnector, PollConnector):
)
],
source=DocumentSource.ZULIP,
semantic_identifier=message.display_recipient or message.subject,
metadata={},
semantic_identifier=f"{message.display_recipient} > {message.subject}",
metadata=metadata,
doc_updated_at=doc_time, # Use most recent edit time or post time
)
def _get_docs(

View File

@@ -1,6 +1,7 @@
from typing import Any
from typing import List
from typing import Optional
from typing import Union
from pydantic import BaseModel
from pydantic import Field
@@ -19,7 +20,7 @@ class Message(BaseModel):
sender_realm_str: str
subject: str
topic_links: Optional[List[Any]] = None
last_edit_timestamp: Optional[int]
last_edit_timestamp: Optional[int] = None
edit_history: Any = None
reactions: List[Any]
submessages: List[Any]
@@ -39,5 +40,5 @@ class GetMessagesResponse(BaseModel):
found_oldest: Optional[bool] = None
found_newest: Optional[bool] = None
history_limited: Optional[bool] = None
anchor: Optional[str] = None
anchor: Optional[Union[str, int]] = None
messages: List[Message] = Field(default_factory=list)

Some files were not shown because too many files have changed in this diff Show More