Compare commits

..

1 Commits

Author SHA1 Message Date
pablodanswer
675aa517f5 k 2025-02-15 10:58:33 -08:00
376 changed files with 4600 additions and 13404 deletions

View File

@@ -53,90 +53,24 @@ jobs:
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: always()
if: ${{ always() }}
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@0.29.0
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# with:
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# sarif_file: trivy-results.sarif

View File

@@ -1,6 +1,6 @@
name: Run Playwright Tests
name: Run Chromatic Tests
concurrency:
group: Run-Playwright-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
group: Run-Chromatic-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on: push
@@ -198,47 +198,43 @@ jobs:
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p danswer-stack down -v
# NOTE: Chromatic UI diff testing is currently disabled.
# We are using Playwright for local and CI testing without visual regression checks.
# Chromatic may be reintroduced in the future for UI diff testing if needed.
chromatic-tests:
name: Chromatic Tests
# chromatic-tests:
# name: Chromatic Tests
needs: playwright-tests
runs-on:
[
runs-on,
runner=32cpu-linux-x64,
disk=large,
"run-id=${{ github.run_id }}",
]
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
# needs: playwright-tests
# runs-on:
# [
# runs-on,
# runner=32cpu-linux-x64,
# disk=large,
# "run-id=${{ github.run_id }}",
# ]
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
# with:
# fetch-depth: 0
- name: Setup node
uses: actions/setup-node@v4
with:
node-version: 22
# - name: Setup node
# uses: actions/setup-node@v4
# with:
# node-version: 22
- name: Install node dependencies
working-directory: ./web
run: npm ci
# - name: Install node dependencies
# working-directory: ./web
# run: npm ci
- name: Download Playwright test results
uses: actions/download-artifact@v4
with:
name: test-results
path: ./web/test-results
# - name: Download Playwright test results
# uses: actions/download-artifact@v4
# with:
# name: test-results
# path: ./web/test-results
# - name: Run Chromatic
# uses: chromaui/action@latest
# with:
# playwright: true
# projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
# workingDir: ./web
# env:
# CHROMATIC_ARCHIVE_LOCATION: ./test-results
- name: Run Chromatic
uses: chromaui/action@latest
with:
playwright: true
projectToken: ${{ secrets.CHROMATIC_PROJECT_TOKEN }}
workingDir: ./web
env:
CHROMATIC_ARCHIVE_LOCATION: ./test-results

View File

@@ -99,7 +99,7 @@ jobs:
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
DEV_MODE=true \
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack up -d
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack up -d
id: start_docker_multi_tenant
# In practice, `cloud` Auth type would require OAUTH credentials to be set.
@@ -108,13 +108,12 @@ jobs:
echo "Waiting for 3 minutes to ensure API server is ready..."
sleep 180
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
@@ -144,28 +143,24 @@ jobs:
- name: Stop multi-tenant Docker containers
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
docker compose -f docker-compose.multitenant-dev.yml -p danswer-stack down -v
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
docker compose -f docker-compose.dev.yml -p danswer-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f onyx-stack-api_server-1 &
docker logs -f danswer-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
@@ -195,24 +190,15 @@ jobs:
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
docker run --rm --network danswer-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
@@ -222,8 +208,6 @@ jobs:
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
@@ -245,13 +229,13 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
docker compose -f docker-compose.dev.yml -p danswer-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
@@ -265,4 +249,4 @@ jobs:
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v
docker compose -f docker-compose.dev.yml -p danswer-stack down -v

View File

@@ -74,9 +74,7 @@ jobs:
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors

View File

@@ -1,29 +1,18 @@
name: Model Server Tests
name: Connector Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
workflow_dispatch:
inputs:
branch:
description: 'Branch to run the workflow on'
required: false
default: 'main'
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# API keys for testing
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
# OpenAI
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
jobs:
model-check:
@@ -37,23 +26,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Model Server Docker image
run: |
docker pull onyxdotapp/onyx-model-server:latest
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
- name: Set up Python
uses: actions/setup-python@v5
with:
@@ -69,49 +41,6 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
@@ -127,23 +56,3 @@ jobs:
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v

View File

@@ -205,7 +205,7 @@
"--loglevel=INFO",
"--hostname=light@%n",
"-Q",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert,checkpoint_cleanup",
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
],
"presentation": {
"group": "2",

124
README.md
View File

@@ -24,93 +24,113 @@
</a>
</p>
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
configuring AI Assistants.
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
supported?" or "Where's the pull request for feature Y?"
<h3>Feature Highlights</h3>
<h3>Usage</h3>
**Deep research over your team's knowledge:**
Onyx Web App:
https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8
https://github.com/onyx-dot-app/onyx/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
**Use Onyx as a secure AI Chat with any LLM:**
![Onyx Chat Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxChatSilentDemo.gif)
**Easily set up connectors to your apps:**
![Onyx Connector Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxConnectorSilentDemo.gif)
**Access Onyx where your team already works:**
![Onyx Bot Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxBot.png)
https://github.com/onyx-dot-app/onyx/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
For more details on the Admin UI to manage connectors and users, check out our
<strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!
## Deployment
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
We also have built-in support for high-availability/scalable deployment on Kubernetes.
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
## 💃 Main Features
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
- Chat UI with the ability to select documents to chat with.
- Create custom AI Assistants with different prompts and backing knowledge sets.
- Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
- Document Search + AI Answers for natural language queries.
- Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
- Slack integration to get answers and search results directly in Slack.
## 🚧 Roadmap
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language
- Chat/Prompt sharing with specific teammates and user groups.
- Multimodal model support, chat with images, video etc.
- Choosing between LLMs and parameters during chat session.
- Tool calling and agent configurations options.
- Organizational understanding and ability to locate and suggest experts from your team.
## Other Notable Benefits of Onyx
- User Authentication with document level access management.
- Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
- Admin Dashboard to configure connectors, document-sets, access, etc.
- Custom deep learning models + learn from user feedback.
- Easy deployment and ability to host Onyx anywhere of your choosing.
## 🔌 Connectors
Keep knowledge and access up to sync across 40+ connectors:
Efficiently pulls the latest changes from:
- Slack
- GitHub
- Google Drive
- Confluence
- Slack
- Gmail
- Salesforce
- Microsoft Sharepoint
- Github
- Jira
- Zendesk
- Gmail
- Notion
- Gong
- Microsoft Teams
- Dropbox
- Slab
- Linear
- Productboard
- Guru
- Bookstack
- Document360
- Sharepoint
- Hubspot
- Local Files
- Websites
- And more ...
See the full list [here](https://docs.onyx.app/connectors).
## 📚 Editions
## 📚 Licensing
There are two editions of Onyx:
- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
For feature details, check out [our website](https://www.onyx.app/pricing).
- Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
- Single Sign-On (SSO), with support for both SAML and OIDC
- Role-based access control
- Document permission inheritance from connected sources
- Usage analytics and query history accessible to admins
- Whitelabeling
- API key authentication
- Encryption of secrets
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
To try the Onyx Enterprise Edition:
1. Checkout [Onyx Cloud](https://cloud.onyx.app/signup).
2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)

View File

@@ -28,11 +28,11 @@ RUN apt-get update && \
curl \
zip \
ca-certificates \
libgnutls30 \
libblkid1 \
libmount1 \
libsmartcols1 \
libuuid1 \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 \
libxmlsec1-dev \
pkg-config \
gcc \

View File

@@ -1,27 +0,0 @@
"""Add indexes to document__tag
Revision ID: 1a03d2c2856b
Revises: 9c00a2bccb83
Create Date: 2025-02-18 10:45:13.957807
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1a03d2c2856b"
down_revision = "9c00a2bccb83"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
op.f("ix_document__tag_tag_id"),
"document__tag",
["tag_id"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_document__tag_tag_id"), table_name="document__tag")

View File

@@ -1,84 +0,0 @@
"""improved index
Revision ID: 3bd4c84fe72f
Revises: 8f43500ee275
Create Date: 2025-02-26 13:07:56.217791
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None
# NOTE:
# This migration addresses issues with the previous migration (8f43500ee275) which caused
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
# (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade() -> None:
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
)
# Also add a stored tsvector column for chat_session.description
op.execute(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
"""
)
# Commit again before creating the second concurrent index
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
)
def downgrade() -> None:
# Drop the indexes first (use CONCURRENTLY for dropping too)
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
# Then drop the columns
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -1,32 +0,0 @@
"""add index
Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create a basic index on the lowercase message column for direct text matching
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
# op.execute(
# """
# CREATE INDEX idx_chat_message_message_lower
# ON chat_message (LOWER(substring(message, 1, 1500)))
# """
# )
pass
def downgrade() -> None:
# Drop the index
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -1,43 +0,0 @@
"""chat_message_agentic
Revision ID: 9c00a2bccb83
Revises: b7a7eee5aa15
Create Date: 2025-02-17 11:15:43.081150
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9c00a2bccb83"
down_revision = "b7a7eee5aa15"
branch_labels = None
depends_on = None
def upgrade() -> None:
# First add the column as nullable
op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))
# Update existing rows based on presence of SubQuestions
op.execute(
"""
UPDATE chat_message
SET is_agentic = EXISTS (
SELECT 1
FROM agent__sub_question
WHERE agent__sub_question.primary_question_id = chat_message.id
)
WHERE is_agentic IS NULL
"""
)
# Make the column non-nullable with a default value of False
op.alter_column(
"chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
)
def downgrade() -> None:
op.drop_column("chat_message", "is_agentic")

View File

@@ -1,29 +0,0 @@
"""remove inactive ccpair status on downgrade
Revision ID: acaab4ef4507
Revises: b388730a2899
Create Date: 2025-02-16 18:21:41.330212
"""
from alembic import op
from onyx.db.models import ConnectorCredentialPair
from onyx.db.enums import ConnectorCredentialPairStatus
from sqlalchemy import update
# revision identifiers, used by Alembic.
revision = "acaab4ef4507"
down_revision = "b388730a2899"
branch_labels = None
depends_on = None
def upgrade() -> None:
pass
def downgrade() -> None:
op.execute(
update(ConnectorCredentialPair)
.where(ConnectorCredentialPair.status == ConnectorCredentialPairStatus.INVALID)
.values(status=ConnectorCredentialPairStatus.ACTIVE)
)

View File

@@ -1,31 +0,0 @@
"""nullable preferences
Revision ID: b388730a2899
Revises: 1a03d2c2856b
Create Date: 2025-02-17 18:49:22.643902
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b388730a2899"
down_revision = "1a03d2c2856b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("user", "temperature_override_enabled", nullable=True)
op.alter_column("user", "auto_scroll", nullable=True)
def downgrade() -> None:
# Ensure no null values before making columns non-nullable
op.execute(
'UPDATE "user" SET temperature_override_enabled = false WHERE temperature_override_enabled IS NULL'
)
op.execute('UPDATE "user" SET auto_scroll = false WHERE auto_scroll IS NULL')
op.alter_column("user", "temperature_override_enabled", nullable=False)
op.alter_column("user", "auto_scroll", nullable=False)

View File

@@ -1,124 +0,0 @@
"""Add checkpointing/failure handling
Revision ID: b7a7eee5aa15
Revises: f39c5794c10a
Create Date: 2025-01-24 15:17:36.763172
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "b7a7eee5aa15"
down_revision = "f39c5794c10a"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"index_attempt",
sa.Column("checkpoint_pointer", sa.String(), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True),
)
op.add_column(
"index_attempt",
sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True),
)
op.create_index(
"ix_index_attempt_cc_pair_settings_poll",
"index_attempt",
[
"connector_credential_pair_id",
"search_settings_id",
"status",
sa.text("time_updated DESC"),
],
)
# Drop the old IndexAttemptError table
op.drop_index("index_attempt_id", table_name="index_attempt_errors")
op.drop_table("index_attempt_errors")
# Create the new version of the table
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("index_attempt_id", sa.Integer(), nullable=False),
sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=True),
sa.Column("document_link", sa.String(), nullable=True),
sa.Column("entity_id", sa.String(), nullable=True),
sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True),
sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True),
sa.Column("failure_message", sa.Text(), nullable=False),
sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
sa.ForeignKeyConstraint(
["connector_credential_pair_id"],
["connector_credential_pair.id"],
),
)
def downgrade() -> None:
op.execute("SET lock_timeout = '5s'")
# try a few times to drop the table, this has been observed to fail due to other locks
# blocking the drop
NUM_TRIES = 10
for i in range(NUM_TRIES):
try:
op.drop_table("index_attempt_errors")
break
except Exception as e:
if i == NUM_TRIES - 1:
raise e
print(f"Error dropping table: {e}. Retrying...")
op.execute("SET lock_timeout = DEFAULT")
# Recreate the old IndexAttemptError table
op.create_table(
"index_attempt_errors",
sa.Column("id", sa.Integer(), primary_key=True),
sa.Column("index_attempt_id", sa.Integer(), nullable=True),
sa.Column("batch", sa.Integer(), nullable=True),
sa.Column("doc_summaries", postgresql.JSONB(), nullable=False),
sa.Column("error_msg", sa.Text(), nullable=True),
sa.Column("traceback", sa.Text(), nullable=True),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
),
sa.ForeignKeyConstraint(
["index_attempt_id"],
["index_attempt.id"],
),
)
op.create_index(
"index_attempt_id",
"index_attempt_errors",
["time_created"],
)
op.drop_index("ix_index_attempt_cc_pair_settings_poll")
op.drop_column("index_attempt", "checkpoint_pointer")
op.drop_column("index_attempt", "poll_range_start")
op.drop_column("index_attempt", "poll_range_end")

View File

@@ -1,120 +0,0 @@
"""migrate jira connectors to new format
Revision ID: da42808081e3
Revises: f13db29f3101
Create Date: 2025-02-24 11:24:54.396040
"""
from alembic import op
import sqlalchemy as sa
import json
from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.utils import extract_jira_project
# revision identifiers, used by Alembic.
revision = "da42808081e3"
down_revision = "f13db29f3101"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config
for connector_id, old_config in jira_connectors:
if not old_config:
continue
# Extract project key from URL if it exists
new_config: dict[str, str | None] = {}
if project_url := old_config.get("jira_project_url"):
# Parse the URL to get base and project
try:
jira_base, project_key = extract_jira_project(project_url)
new_config = {"jira_base_url": jira_base, "project_key": project_key}
except ValueError:
# If URL parsing fails, just use the URL as the base
new_config = {
"jira_base_url": project_url.split("/projects/")[0],
"project_key": None,
}
else:
# For connectors without a project URL, we need admin intervention
# Mark these for review
print(
f"WARNING: Jira connector {connector_id} has no project URL configured"
)
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :id
"""
),
{"id": connector_id, "new_config": json.dumps(new_config)},
)
def downgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config back to the old format
for connector_id, new_config in jira_connectors:
if not new_config:
continue
old_config = {}
base_url = new_config.get("jira_base_url")
project_key = new_config.get("project_key")
if base_url and project_key:
old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
elif base_url:
old_config = {"jira_project_url": base_url}
else:
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :old_config
WHERE id = :id
"""
),
{"id": connector_id, "old_config": old_config},
)

View File

@@ -1,36 +0,0 @@
"""force lowercase all users
Revision ID: f11b408e39d3
Revises: 3bd4c84fe72f
Create Date: 2025-02-26 17:04:55.683500
"""
# revision identifiers, used by Alembic.
revision = "f11b408e39d3"
down_revision = "3bd4c84fe72f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing user emails to lowercase
from alembic import op
op.execute(
"""
UPDATE "user"
SET email = LOWER(email)
"""
)
# 2) Add a check constraint to ensure emails are always lowercase
op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")
def downgrade() -> None:
# Drop the check constraint
from alembic import op
op.drop_constraint("ensure_lowercase_email", "user", type_="check")

View File

@@ -1,27 +0,0 @@
"""Add composite index for last_modified and last_synced to document
Revision ID: f13db29f3101
Revises: b388730a2899
Create Date: 2025-02-18 22:48:11.511389
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f13db29f3101"
down_revision = "acaab4ef4507"
branch_labels: str | None = None
depends_on: str | None = None
def upgrade() -> None:
op.create_index(
"ix_document_sync_status",
"document",
["last_modified", "last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index("ix_document_sync_status", table_name="document")

View File

@@ -1,42 +0,0 @@
"""lowercase multi-tenant user auth
Revision ID: 34e3630c7f32
Revises: a4f6ee863c47
Create Date: 2025-02-26 15:03:01.211894
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "34e3630c7f32"
down_revision = "a4f6ee863c47"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing rows to lowercase
op.execute(
"""
UPDATE user_tenant_mapping
SET email = LOWER(email)
"""
)
# 2) Add a check constraint so that emails cannot be written in uppercase
op.create_check_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
"email = LOWER(email)",
schema="public",
)
def downgrade() -> None:
# Drop the check constraint
op.drop_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
schema="public",
type_="check",
)

View File

@@ -5,9 +5,11 @@ from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -16,8 +18,10 @@ logger = setup_logger()
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
def perform_ttl_management_task(
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -31,19 +35,24 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str) -> None:
def check_ttl_management_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
if MULTI_TENANT and tenant_id is not None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@celery_app.task(
@@ -51,9 +60,9 @@ def check_ttl_management_task(*, tenant_id: str) -> None:
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,

View File

@@ -18,7 +18,7 @@ logger = setup_logger()
def monitor_usergroup_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
"""This function is likely to move in the worker refactor happening next."""
fence_key = key_bytes.decode("utf-8")

View File

@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import UserGroup__ConnectorCredentialPair
@@ -36,11 +35,10 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
def get_cc_pairs_by_source(
db_session: Session,
source_type: DocumentSource,
access_type: AccessType | None = None,
status: ConnectorCredentialPairStatus | None = None,
only_sync: bool,
) -> list[ConnectorCredentialPair]:
"""
Get all cc_pairs for a given source type with optional filtering by access_type and status
Get all cc_pairs for a given source type (and optionally only sync)
result is sorted by cc_pair id
"""
query = (
@@ -50,11 +48,8 @@ def get_cc_pairs_by_source(
.order_by(ConnectorCredentialPair.id)
)
if access_type is not None:
query = query.filter(ConnectorCredentialPair.access_type == access_type)
if status is not None:
query = query.filter(ConnectorCredentialPair.status == status)
if only_sync:
query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)
cc_pairs = query.all()
return cc_pairs

View File

@@ -14,24 +14,30 @@ def _build_group_member_email_map(
confluence_client: OnyxConfluence, cc_pair_id: int
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user}")
for user_result in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user_result}")
email = user.email
user = user_result.get("user", {})
if not user:
msg = f"user result missing user field: {user_result}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
continue
email = user.get("email")
if not email:
# This field is only present in Confluence Server
user_name = user.username
user_name = user.get("username")
# If it is present, try to get the email using a Server-specific method
if user_name:
email = get_user_email_from_username__server(
confluence_client=confluence_client,
user_name=user_name,
)
if not email:
# If we still don't have an email, skip this user
msg = f"user result missing email field: {user}"
if user.type == "app":
msg = f"user result missing email field: {user_result}"
if user.get("type") == "app":
logger.warning(msg)
else:
emit_background_error(msg, cc_pair_id=cc_pair_id)
@@ -39,7 +45,7 @@ def _build_group_member_email_map(
continue
all_users_groups: set[str] = set()
for group in confluence_client.paginated_groups_by_user_retrieval(user.user_id):
for group in confluence_client.paginated_groups_by_user_retrieval(user):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)

View File

@@ -62,14 +62,12 @@ def _fetch_permissions_for_permission_ids(
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# We continue on 404 or 403 because the document may not exist or the user may not have access to it
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
fileId=doc_id,
fields="permissions(id, emailAddress, type, domain)",
supportsAllDrives=True,
continue_on_404_or_403=True,
)
permissions_for_doc_id = []
@@ -106,13 +104,7 @@ def _get_permissions_from_slim_doc(
user_emails: set[str] = set()
group_emails: set[str] = set()
public = False
skipped_permissions = 0
for permission in permissions_list:
if not permission:
skipped_permissions += 1
continue
permission_type = permission["type"]
if permission_type == "user":
user_emails.add(permission["emailAddress"])
@@ -129,11 +121,6 @@ def _get_permissions_from_slim_doc(
elif permission_type == "anyone":
public = True
if skipped_permissions > 0:
logger.warning(
f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
)
drive_id = permission_info.get("drive_id")
group_ids = group_emails | ({drive_id} if drive_id is not None else set())

View File

@@ -5,7 +5,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -17,7 +17,7 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config)
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
slim_doc_generator = slack_connector.retrieve_all_slim_documents(callback=callback)

View File

@@ -33,7 +33,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
return await call_next(request)
except Exception as e:
logger.exception(f"Error in tenant ID middleware: {str(e)}")
logger.error(f"Error in tenant ID middleware: {str(e)}")
raise
@@ -49,7 +49,7 @@ async def _get_tenant_id_from_request(
"""
# Check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id is not None:
if tenant_id:
return tenant_id
# Check for anonymous user cookie

View File

@@ -36,12 +36,12 @@ from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -271,12 +271,12 @@ def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
tenant_id = get_current_tenant_id()
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
@@ -329,6 +329,7 @@ def handle_slack_oauth_callback(
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
@@ -336,7 +337,7 @@ def handle_slack_oauth_callback(
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
@@ -522,6 +523,7 @@ def handle_google_drive_oauth_callback(
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
@@ -529,7 +531,7 @@ def handle_google_drive_oauth_callback(
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (

View File

@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
def _check_token_rate_limits(user: User | None) -> None:
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global()
_user_is_rate_limited_by_global(tenant_id)
elif is_api_key_email_address(user.email):
# API keys are only rate limited by global settings
_user_is_rate_limited_by_global()
_user_is_rate_limited_by_global(tenant_id)
else:
run_functions_tuples_in_parallel(
[
(_user_is_rate_limited, (user.id,)),
(_user_is_rate_limited_by_group, (user.id,)),
(_user_is_rate_limited_by_global, ()),
(_user_is_rate_limited, (user.id, tenant_id)),
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
(_user_is_rate_limited_by_global, (tenant_id,)),
]
)
@@ -52,8 +52,8 @@ User rate limits
"""
def _user_is_rate_limited(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
@@ -93,8 +93,8 @@ User Group rate limits
"""
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
if group_rate_limits:

View File

@@ -2,7 +2,6 @@ import csv
import io
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from uuid import UUID
from fastapi import APIRouter
@@ -22,10 +21,8 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
@@ -38,8 +35,6 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter()
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
def fetch_and_process_chat_session_history(
db_session: Session,
@@ -112,17 +107,6 @@ def get_user_chat_sessions(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
# we specifically don't allow this endpoint if "anonymized" since
# this is a direct query on the user id
if ONYX_QUERY_HISTORY_TYPE in [
QueryHistoryType.DISABLED,
QueryHistoryType.ANONYMIZED,
]:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Per user query history has been disabled by the administrator.",
)
try:
chat_sessions = get_chat_sessions_by_user(
user_id=user_id, deleted=False, db_session=db_session, limit=0
@@ -138,7 +122,6 @@ def get_user_chat_sessions(
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
@@ -158,12 +141,6 @@ def get_chat_session_history(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
page_size=page_size,
@@ -180,16 +157,11 @@ def get_chat_session_history(
feedback_filter=feedback_type,
)
minimal_chat_sessions: list[ChatSessionMinimal] = []
for chat_session in page_of_chat_sessions:
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
minimal_chat_sessions.append(minimal_chat_session)
return PaginatedReturn(
items=minimal_chat_sessions,
items=[
ChatSessionMinimal.from_chat_session(chat_session)
for chat_session in page_of_chat_sessions
],
total_items=total_filtered_chat_sessions_count,
)
@@ -200,12 +172,6 @@ def get_chat_session_admin(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
try:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
@@ -227,9 +193,6 @@ def get_chat_session_admin(
f"Could not create snapshot for chat session with id '{chat_session_id}'",
)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
return snapshot
@@ -240,12 +203,6 @@ def get_query_history_as_csv(
end: datetime | None = None,
db_session: Session = Depends(get_session),
) -> StreamingResponse:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
@@ -256,9 +213,6 @@ def get_query_history_as_csv(
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
for chat_session_snapshot in complete_chat_session_history:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
question_answer_pairs.extend(
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
)

View File

@@ -41,15 +41,14 @@ from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
@@ -58,14 +57,13 @@ router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
tenant_id: str | None = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@@ -74,15 +72,15 @@ async def get_anonymous_user_path_api(
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
tenant_id: str = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
@@ -103,7 +101,7 @@ async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
@@ -152,17 +150,14 @@ async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
try:
# Fetch tenant_id and current tenant's information
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
@@ -186,8 +181,6 @@ async def create_subscription_session(
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
@@ -204,7 +197,7 @@ async def impersonate_user(
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
with get_session_with_tenant(tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
@@ -228,9 +221,8 @@ async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"

View File

@@ -7,7 +7,6 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
@@ -42,9 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()
def fetch_billing_information(
tenant_id: str,
) -> BillingInformation | SubscriptionStatusResponse:
def fetch_billing_information(tenant_id: str) -> BillingInformation:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@@ -55,19 +52,8 @@ def fetch_billing_information(
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
response_data = response.json()
# Check if the response indicates no subscription
if (
isinstance(response_data, dict)
and "subscribed" in response_data
and not response_data["subscribed"]
):
return SubscriptionStatusResponse(**response_data)
# Otherwise, parse as BillingInformation
return BillingInformation(**response_data)
billing_info = BillingInformation(**response.json())
return billing_info
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:

View File

@@ -118,7 +118,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
@@ -134,7 +134,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
@@ -200,35 +200,14 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
def configure_default_api_keys(db_session: Session) -> None:
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
default_model_name="gpt-4",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
@@ -240,6 +219,25 @@ def configure_default_api_keys(db_session: Session) -> None:
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,

View File

@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()

View File

@@ -98,17 +98,12 @@ class CloudEmbedding:
return final_embeddings
except Exception as e:
error_string = (
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
f"Error embedding text with OpenAI: {str(e)} \n"
f"Model: {model} \n"
f"Provider: {self.provider} \n"
f"Texts: {texts}"
)
logger.error(error_string)
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
async def _embed_cohere(

View File

@@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
)
graph.add_node(
node="choose_tool",
action=choose_tool,
node="llm_tool_choice",
action=llm_tool_choice,
)
graph.add_node(
node="call_tool",
action=call_tool,
node="tool_call",
action=tool_call,
)
graph.add_node(
@@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_edge(
start_key="call_tool",
start_key="tool_call",
end_key="basic_use_tool_response",
)
@@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
# If there are no tool calls, basic graph already streamed the answer
END
if state.tool_choice is None
else "call_tool"
else "tool_call"
)

View File

@@ -31,14 +31,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -87,11 +85,9 @@ def check_sub_answer(
agent_error: AgentErrorLog | None = None
response: BaseMessage | None = None
try:
response = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
fast_llm.invoke,
response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
)
quality_str: str = cast(str, response.content)
@@ -100,7 +96,7 @@ def check_sub_answer(
)
log_result = f"Answer quality: {quality_str}"
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import merge_message_runs
@@ -46,13 +47,11 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -111,14 +110,15 @@ def generate_sub_answer(
config=fast_llm.config,
)
response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
response: list[str] = []
def stream_sub_answer() -> list[str]:
agent_error: AgentErrorLog | None = None
try:
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -142,15 +142,8 @@ def generate_sub_answer(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
try:
response = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
stream_sub_answer,
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
@@ -59,15 +60,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -80,7 +77,6 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
_llm_node_error_strings = LLMNodeErrorStrings(
@@ -234,11 +230,7 @@ def generate_initial_answer(
sub_questions = all_sub_questions # Replace the original assignment
model = (
graph_config.tooling.fast_llm
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
else graph_config.tooling.primary_llm
)
model = graph_config.tooling.fast_llm
doc_context = format_docs(answer_generation_documents.context_documents)
doc_context = trim_prompt_piece(
@@ -268,16 +260,15 @@ def generate_initial_answer(
)
]
streamed_tokens: list[str] = [""]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
def stream_initial_answer() -> list[str]:
response: list[str] = []
try:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -301,16 +292,9 @@ def generate_initial_answer(
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
streamed_tokens.append(content)
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
stream_initial_answer,
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -36,10 +36,7 @@ from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -50,7 +47,6 @@ from onyx.prompts.agent_search import (
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -135,12 +131,10 @@ def decompose_orig_question(
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
dispatch_separated,
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
@@ -160,7 +154,7 @@ def decompose_orig_question(
)
write_custom_event("stream_finished", stop_event, writer)
except (LLMTimeoutError, TimeoutError) as e:
except LLMTimeoutError as e:
logger.error("LLM Timeout Error - decompose orig question")
raise e # fail loudly on this critical step
except LLMRateLimitError as e:

View File

@@ -25,7 +25,7 @@ logger = setup_logger()
def route_initial_tool_choice(
state: MainState, config: RunnableConfig
) -> Literal["call_tool", "start_agent_search", "logging_node"]:
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
"""
LangGraph edge to route to agent search.
"""
@@ -38,7 +38,7 @@ def route_initial_tool_choice(
):
return "start_agent_search"
else:
return "call_tool"
return "tool_call"
else:
return "logging_node"

View File

@@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
@@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
# Choose the initial tool
graph.add_node(
node="initial_tool_choice",
action=choose_tool,
action=llm_tool_choice,
)
# Call the tool, if required
graph.add_node(
node="call_tool",
action=call_tool,
node="tool_call",
action=tool_call,
)
# Use the tool response
@@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph.add_conditional_edges(
"initial_tool_choice",
route_initial_tool_choice,
["call_tool", "start_agent_search", "logging_node"],
["tool_call", "start_agent_search", "logging_node"],
)
graph.add_edge(
start_key="call_tool",
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(

View File

@@ -33,15 +33,13 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -107,14 +105,11 @@ def compare_answers(
refined_answer_improvement: bool | None = None
# no need to stream this
try:
resp = run_with_timeout(
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
resp = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -44,10 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -56,7 +53,6 @@ from onyx.prompts.agent_search import (
)
from onyx.tools.models import ToolCallKickoff
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -138,17 +134,15 @@ def create_refined_sub_questions(
agent_error: AgentErrorLog | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
dispatch_separated,
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -22,17 +22,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
@@ -90,42 +84,30 @@ def extract_entities_terms(
]
fast_llm = graph_config.tooling.fast_llm
# Grader
llm_response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
llm_response = run_with_timeout(
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
entity_extraction_result = EntityExtractionResult.model_validate_json(
cleaned_response
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
entity_extraction_result = EntityExtractionResult.model_validate_json(
cleaned_response
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - extract entities terms")
except ValueError:
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
except LLMRateLimitError:
logger.error("LLM Rate Limit Error - extract entities terms")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
retrieved_entities_relationships=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
)
return EntityTermExtractionUpdate(

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
@@ -65,21 +66,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -98,7 +92,6 @@ from onyx.prompts.agent_search import (
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -260,12 +253,7 @@ def generate_validate_refined_answer(
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
model = (
graph_config.tooling.fast_llm
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
else graph_config.tooling.primary_llm
)
model = graph_config.tooling.fast_llm
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
relevant_docs_str = trim_prompt_piece(
model.config,
@@ -296,13 +284,13 @@ def generate_validate_refined_answer(
)
]
streamed_tokens: list[str] = [""]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
def stream_refined_answer() -> list[str]:
try:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -327,15 +315,8 @@ def generate_validate_refined_answer(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
return streamed_tokens
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
stream_refined_answer,
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
@@ -402,20 +383,16 @@ def generate_validate_refined_answer(
)
]
validation_model = graph_config.tooling.fast_llm
try:
validation_response = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
validation_response = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),
positive_value=AGENT_POSITIVE_VALUE_STR,
separator=AGENT_ANSWER_SEPARATOR,
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
refined_answer_quality = True
logger.error("LLM Timeout Error - validate refined answer")

View File

@@ -34,16 +34,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -71,7 +69,7 @@ def expand_queries(
node_start_time = datetime.now()
question = state.question
model = graph_config.tooling.fast_llm
llm = graph_config.tooling.fast_llm
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_num = 0, 0
@@ -90,12 +88,10 @@ def expand_queries(
rewritten_queries = []
try:
llm_response_list = run_with_timeout(
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
dispatch_separated,
model.stream(
llm_response_list = dispatch_separated(
llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)
@@ -105,7 +101,7 @@ def expand_queries(
rewritten_queries = llm_response.split("\n")
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -55,7 +55,6 @@ def rerank_documents(
# Note that these are passed in values from the API and are overrides which are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
if rerank_settings is None:
with get_session_context_manager() as db_session:
@@ -63,31 +62,23 @@ def rerank_documents(
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)
# Initial default: no reranking. Will be overwritten below if reranking is warranted
reranked_documents = verified_documents
if should_rerank(rerank_settings) and len(verified_documents) > 0:
if len(verified_documents) > 1:
if not allow_agent_reranking:
logger.info("Use of local rerank model without GPU, skipping reranking")
# No reranking, stay with verified_documents as default
else:
# Reranking is warranted, use the rerank_sections functon
reranked_documents = rerank_sections(
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
reranked_documents = rerank_sections(
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
else:
logger.warning(
f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
# No reranking, stay with verified_documents as default
reranked_documents = verified_documents
else:
logger.warning("No reranking settings found, using unranked documents")
# No reranking, stay with verified_documents as default
reranked_documents = verified_documents
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:

View File

@@ -25,15 +25,13 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
DOCUMENT_VERIFICATION_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -88,11 +86,8 @@ def verify_documents(
] # default is to treat document as relevant
try:
response = run_with_timeout(
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
response = fast_llm.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
)
assert isinstance(response.content, str)
@@ -101,7 +96,7 @@ def verify_documents(
):
verified_documents = []
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
# In this case, we decide to continue and don't raise an error, as
# little harm in letting some docs through that are less relevant.
logger.error("LLM Timeout Error - verify documents")

View File

@@ -67,7 +67,6 @@ class GraphSearchConfig(BaseModel):
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
allow_agent_reranking: bool = False
class GraphConfig(BaseModel):

View File

@@ -25,7 +25,7 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def choose_tool(
def llm_tool_choice(
state: ToolChoiceState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,

View File

@@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
write_custom_event("basic_response", packet, writer)
def call_tool(
def tool_call(
state: ToolChoiceUpdate,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,

View File

@@ -43,9 +43,8 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
@@ -81,7 +80,6 @@ from onyx.tools.tool_implementations.search.search_tool import SearchResponseSum
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
@@ -397,13 +395,11 @@ def summarize_history(
)
try:
history_response = run_with_timeout(
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION,
llm.invoke,
history_response = llm.invoke(
history_context_prompt,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
logger.error("LLM Timeout Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this

View File

@@ -10,7 +10,6 @@ from pydantic import BaseModel
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
from shared_configs.configs import MULTI_TENANT
_API_KEY_HEADER_NAME = "Authorization"
@@ -36,7 +35,8 @@ class ApiKeyDescriptor(BaseModel):
def generate_api_key(tenant_id: str | None = None) -> str:
if not MULTI_TENANT or not tenant_id:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
encoded_tenant = quote(tenant_id) # URL encode the tenant ID

View File

@@ -2,8 +2,6 @@ import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
from email.utils import make_msgid
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
@@ -12,10 +10,8 @@ from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from shared_configs.configs import MULTI_TENANT
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
@@ -153,9 +149,8 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
if mail_from:
msg["From"] = mail_from
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
@@ -177,7 +172,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
subject = "Your Onyx Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>We're sorry to see you go.</p>"
"<p>Were sorry to see you go.</p>"
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
"<p>If you change your mind, you can always come back!</p>"
)
@@ -192,64 +187,36 @@ def send_subscription_cancellation_email(user_email: str) -> None:
send_email(user_email, subject, html_content, text_content)
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Organization"
heading = "You've Been Invited!"
# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
elif auth_type == AuthType.BASIC:
message += (
"<p>To join the organization, please click the button below to set a password "
"and complete your registration.</p>"
)
elif auth_type == AuthType.GOOGLE_OAUTH:
message += (
"<p>To join the organization, please click the button below to login with Google "
"and complete your registration.</p>"
)
elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML:
message += (
"<p>To join the organization, please click the button below to"
" complete your registration.</p>"
)
else:
raise ValueError(f"Invalid auth type: {auth_type}")
message = (
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
"You'll be asked to set a password or login with Google to complete your registration."
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."
send_email(user_email, subject, html_content, text_content)
def send_forgot_password_email(
user_email: str,
token: str,
tenant_id: str,
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
if tenant_id:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)

View File

@@ -42,5 +42,4 @@ def fetch_no_auth_user(
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
is_anonymous_user=anonymous_user_enabled,
password_configured=False,
)

View File

@@ -1,7 +1,5 @@
import json
import random
import secrets
import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
@@ -88,6 +86,7 @@ from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
@@ -95,7 +94,6 @@ from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
@@ -105,11 +103,15 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True
@@ -141,30 +143,6 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
return email or ""
def generate_password() -> str:
lowercase_letters = string.ascii_lowercase
uppercase_letters = string.ascii_uppercase
digits = string.digits
special_characters = string.punctuation
# Ensure at least one of each required character type
password = [
secrets.choice(uppercase_letters),
secrets.choice(digits),
secrets.choice(special_characters),
]
# Fill the rest with a mix of characters
remaining_length = 12 - len(password)
all_characters = lowercase_letters + uppercase_letters + digits + special_characters
password.extend(secrets.choice(all_characters) for _ in range(remaining_length))
# Shuffle the password to randomize the position of the required characters
random.shuffle(password)
return "".join(password)
def user_needs_to_be_verified() -> bool:
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
return REQUIRE_EMAIL_VERIFICATION
@@ -214,8 +192,8 @@ def verify_email_is_invited(email: str) -> None:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -411,7 +389,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"refresh_token": refresh_token,
}
user: User | None = None
user: User
try:
# Attempt to get user by OAuth account
@@ -420,20 +398,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = await self.user_db.get_by_email(account_email)
user = await self.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()
# Make sure user is not None before adding OAuth account
if user is not None:
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
else:
# This shouldn't happen since get_by_email would raise UserNotExists
# but adding as a safeguard
raise exceptions.UserNotExists()
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# If user not found by OAuth account or email, create a new user
except exceptions.UserNotExists:
password = self.password_helper.generate()
user_dict = {
@@ -444,36 +417,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user = await self.user_db.create(user_dict)
# Add OAuth account only if user creation was successful
if user is not None:
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
raise HTTPException(
status_code=500, detail="Failed to create user account"
)
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
# User exists, update OAuth account if needed
if user is not None: # Add explicit check
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# Ensure user is not None before proceeding
if user is None:
raise HTTPException(
status_code=500, detail="Failed to authenticate or create user"
)
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
@@ -568,7 +531,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async_return_default_schema,
)(email=user.email)
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
@@ -632,39 +595,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return user
async def reset_password_as_admin(self, user_id: uuid.UUID) -> str:
"""Admin-only. Generate a random password for a user and return it."""
user = await self.get(user_id)
new_password = generate_password()
await self._update(user, {"password": new_password})
return new_password
async def change_password_if_old_matches(
self, user: User, old_password: str, new_password: str
) -> None:
"""
For normal users to change password if they know the old one.
Raises 400 if old password doesn't match.
"""
verified, updated_password_hash = self.password_helper.verify_and_update(
old_password, user.hashed_password
)
if not verified:
# Raise some HTTPException (or your custom exception) if old password is invalid:
from fastapi import HTTPException, status
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid current password",
)
# If the hash was upgraded behind the scenes, we can keep it before setting the new password:
if updated_password_hash:
user.hashed_password = updated_password_hash
# Now apply and validate the new password
await self._update(user, {"password": new_password})
async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
@@ -889,9 +819,8 @@ async def current_limited_user(
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> User | None:
tenant_id = get_current_tenant_id()
return await double_check_user(
user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
)

View File

@@ -2,7 +2,6 @@ import logging
import multiprocessing
import time
from typing import Any
from typing import cast
import sentry_sdk
from celery import Task
@@ -34,7 +33,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
@@ -60,35 +58,13 @@ else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
class TenantAwareTask(Task):
"""A custom base Task that sets tenant_id in a contextvar before running."""
abstract = True # So Celery knows not to register this as a real task.
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Grab tenant_id from the kwargs, or fallback to default if missing.
tenant_id = kwargs.get("tenant_id", None) or POSTGRES_DEFAULT_SCHEMA
# Set the context var
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Actually run the task now
try:
return super().__call__(*args, **kwargs)
finally:
# Clear or reset after the task runs
# so it does not leak into any subsequent tasks on the same worker process
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
@task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
**kwds: Any,
) -> None:
pass
@@ -132,9 +108,9 @@ def on_task_postrun(
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = POSTGRES_DEFAULT_SCHEMA
tenant_id = None
else:
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
@@ -225,7 +201,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_shared_redis_client()
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@@ -311,7 +287,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_shared_redis_client()
r = get_redis_client(tenant_id=None)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")
@@ -463,6 +439,24 @@ class TenantContextFilter(logging.Filter):
return True
@task_prerun.connect
def set_tenant_id(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
"""Signal handler to set tenant ID in context var before task starts."""
tenant_id = (
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if kwargs
else POSTGRES_DEFAULT_SCHEMA
)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@task_postrun.connect
def reset_tenant_id(
sender: Any | None = None,

View File

@@ -132,7 +132,6 @@ class DynamicTenantScheduler(PersistentScheduler):
f"Adding options to task {tenant_task_name}: {options}"
)
tenant_task["options"] = options
new_schedule[tenant_task_name] = tenant_task
return new_schedule
@@ -257,4 +256,3 @@ def on_setup_logging(
celery_app.conf.beat_scheduler = DynamicTenantScheduler
celery_app.conf.task_default_base = app_base.TenantAwareTask

View File

@@ -20,7 +20,6 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.heavy")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -21,7 +21,6 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.indexing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -23,7 +23,6 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.light")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -20,7 +20,6 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -24,7 +24,7 @@ from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_default_tenant
from onyx.db.engine import SqlEngine
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
@@ -38,7 +38,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -47,7 +47,6 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.primary")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
@@ -102,7 +101,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_shared_redis_client()
r = get_redis_client(tenant_id=None)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
@@ -159,7 +158,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
with get_session_with_current_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
@@ -235,7 +234,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
lock: RedisLock = worker.primary_worker_lock
r = get_shared_redis_client()
r = get_redis_client(tenant_id=None)
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")

View File

@@ -92,8 +92,7 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue and return them
as a set.
"""This is a redis specific way to build a list of tasks in a queue.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.

View File

@@ -34,7 +34,7 @@ def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
tenant_id: str | None = None,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
tenant_id: str | None = None,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id

View File

@@ -36,15 +36,6 @@ beat_task_templates.extend(
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-checkpoint-cleanup",
"task": OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
"schedule": timedelta(hours=1),
"options": {
"priority": OnyxCeleryPriority.LOW,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": "check-for-connector-deletion",
"task": OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,

View File

@@ -8,21 +8,16 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
@@ -32,7 +27,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -57,51 +52,6 @@ class TaskDependencyError(RuntimeError):
with connector deletion."""
def revoke_tasks_blocking_deletion(
redis_connector: RedisConnector, db_session: Session, app: Celery
) -> None:
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
try:
index_payload = redis_connector_index.payload
if index_payload and index_payload.celery_task_id:
app.control.revoke(index_payload.celery_task_id)
task_logger.info(
f"Revoked indexing task {index_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking indexing task")
try:
permissions_sync_payload = redis_connector.permissions.payload
if permissions_sync_payload and permissions_sync_payload.celery_task_id:
app.control.revoke(permissions_sync_payload.celery_task_id)
task_logger.info(
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking pruning task")
try:
prune_payload = redis_connector.prune.payload
if prune_payload and prune_payload.celery_task_id:
app.control.revoke(prune_payload.celery_task_id)
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
except Exception:
task_logger.exception("Exception while revoking permissions sync task")
try:
external_group_sync_payload = redis_connector.external_group_sync.payload
if external_group_sync_payload and external_group_sync_payload.celery_task_id:
app.control.revoke(external_group_sync_payload.celery_task_id)
task_logger.info(
f"Revoked external group sync task {external_group_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking external group sync task")
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
ignore_result=True,
@@ -109,46 +59,32 @@ def revoke_tasks_blocking_deletion(
trail=False,
bind=True,
)
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# Prevent this task from overlapping with itself
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)
except Exception:
task_logger.exception(
"Exception while validating connector deletion fences"
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
cc_pair_ids.append(cc_pair.id)
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
@@ -156,38 +92,9 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# on the first error, we set a stop signal and revoke the dependent tasks
# on subsequent errors, we hard reset blocking fences after our specified timeout
# is exceeded
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
if not redis_connector.stop.fenced:
# one time revoke of celery tasks
task_logger.info("Revoking any tasks blocking deletion.")
revoke_tasks_blocking_deletion(
redis_connector, db_session, self.app
)
redis_connector.stop.set_fence(True)
redis_connector.stop.set_timeout()
else:
# stop signal already set
if redis_connector.stop.timed_out:
# waiting too long, just reset blocking fences
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()
else:
# just wait
pass
redis_connector.stop.set_fence(True)
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
@@ -222,7 +129,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
cc_pair_id: int,
db_session: Session,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
@@ -262,7 +169,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
redis_connector.delete.set_active()
fence_payload = RedisConnectorDeletePayload(
num_tasks=None,
submitted=datetime.now(timezone.utc),
@@ -343,7 +249,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
def monitor_connector_deletion_taskset(
tenant_id: str, key_bytes: bytes, r: Redis
tenant_id: str | None, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -371,7 +277,7 @@ def monitor_connector_deletion_taskset(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
@@ -381,7 +287,7 @@ def monitor_connector_deletion_taskset(
)
return
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -495,171 +401,3 @@ def monitor_connector_deletion_taskset(
)
redis_connector.delete.reset()
def validate_connector_deletion_fences(
tenant_id: str,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
# validate all existing connector deletion jobs
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
continue
validate_connector_deletion_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
r,
)
lock_beat.reacquire()
return
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.delete.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.delete.payload
except ValidationError:
task_logger.exception(
"validate_connector_deletion_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return
if not payload:
return
# OK, there's actually something for us to validate
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own permissions taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.delete.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_connector_deletion_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're active if there are still tasks to run and those tasks all exist in celery
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.delete.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.delete.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_connector_deletion_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return

View File

@@ -30,7 +30,6 @@ from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -43,12 +42,10 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -66,7 +63,6 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
@@ -123,13 +119,13 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None:
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# TODO(rkuo): merge into check function after lookup table for fences is added
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -144,7 +140,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
try:
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
@@ -193,23 +189,16 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
task_logger.info(f"check_for_doc_permissions_sync finished: tenant={tenant_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id} {error_msg}"
)
task_logger.exception(
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id}"
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
@@ -221,7 +210,7 @@ def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Returns a randomized payload id on success.
Returns None if no syncing is required."""
@@ -258,7 +247,7 @@ def try_creating_permissions_sync_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
@@ -293,19 +282,13 @@ def try_creating_permissions_sync_task(
redis_connector.permissions.set_fence(payload)
payload_id = payload.id
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_permissions_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"try_creating_permissions_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
return payload_id
@@ -320,7 +303,7 @@ def try_creating_permissions_sync_task(
def connector_permission_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
@@ -338,7 +321,7 @@ def connector_permission_sync_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -395,7 +378,7 @@ def connector_permission_sync_generator_task(
return None
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -405,29 +388,6 @@ def connector_permission_sync_generator_task(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -479,10 +439,6 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
)
task_logger.exception(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
)
@@ -509,7 +465,7 @@ def connector_permission_sync_generator_task(
)
def update_external_document_permissions_task(
self: Task,
tenant_id: str,
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
@@ -517,8 +473,6 @@ def update_external_document_permissions_task(
) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
@@ -526,8 +480,7 @@ def update_external_document_permissions_task(
external_access = document_external_access.external_access
try:
with get_session_with_current_tenant() as db_session:
# Add the users to the DB if they don't exist
with get_session_with_tenant(tenant_id) as db_session:
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
@@ -558,33 +511,18 @@ def update_external_document_permissions_task(
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
)
except Exception:
task_logger.exception(
f"update_external_document_permissions_task exceptioned: "
f"Exception in update_external_document_permissions_task: "
f"connector_id={connector_id} doc_id={doc_id}"
)
completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
finally:
task_logger.info(
f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
task_logger.info(
f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
)
return True
def validate_permission_sync_fences(
tenant_id: str,
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
@@ -631,7 +569,7 @@ def validate_permission_sync_fences(
def validate_permission_sync_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
reserved_tasks: set[str],
@@ -841,7 +779,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
def monitor_ccpair_permissions_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)

View File

@@ -37,12 +37,9 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -58,7 +55,6 @@ from onyx.redis.redis_connector_ext_group_sync import (
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -123,11 +119,11 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -144,7 +140,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
try:
cc_pair_ids_to_sync: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
# We only want to sync one cc_pair per source type in
@@ -152,10 +148,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session,
source,
access_type=AccessType.SYNC,
status=ConnectorCredentialPairStatus.ACTIVE,
db_session, source, only_sync=True
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
@@ -202,17 +195,12 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected check_for_external_group_sync exception: tenant={tenant_id} {error_msg}"
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
task_logger.info(f"check_for_external_group_sync finished: tenant={tenant_id}")
return True
@@ -220,7 +208,7 @@ def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
@@ -242,7 +230,7 @@ def try_creating_external_group_sync_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
@@ -279,19 +267,12 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.set_fence(payload)
payload_id = payload.id
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_external_group_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
task_logger.info(
f"try_creating_external_group_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
return payload_id
@@ -306,7 +287,7 @@ def try_creating_external_group_sync_task(
def connector_external_group_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""
External group sync task for a given connector credential pair
@@ -315,7 +296,7 @@ def connector_external_group_sync_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -376,40 +357,16 @@ def connector_external_group_sync_generator_task(
payload.started = datetime.now(timezone.utc)
redis_connector.external_group_sync.set_fence(payload)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -421,23 +378,12 @@ def connector_external_group_sync_generator_task(
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = []
try:
external_user_groups = ext_group_sync_func(cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise e
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
)
logger.debug(f"New external user groups: {external_user_groups}")
replace_user__ext_group_for_cc_pair(
db_session=db_session,
@@ -458,19 +404,11 @@ def connector_external_group_sync_generator_task(
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id} {error_msg}"
)
task_logger.exception(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
)
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
task_logger.exception(msg)
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
@@ -493,7 +431,7 @@ def connector_external_group_sync_generator_task(
def validate_external_group_sync_fences(
tenant_id: str,
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_replica: Redis,
@@ -521,11 +459,12 @@ def validate_external_group_sync_fences(
)
lock_beat.reacquire()
return
def validate_external_group_sync_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,

View File

@@ -1,10 +1,9 @@
import multiprocessing
import os
import sys
import time
import traceback
from datetime import datetime
from datetime import timezone
from enum import Enum
from http import HTTPStatus
from time import sleep
from typing import Any
@@ -16,7 +15,6 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from celery.result import AsyncResult
from celery.states import READY_STATES
from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -28,31 +26,22 @@ from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attem
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
from onyx.background.indexing.checkpointing_utils import (
get_index_attempts_with_old_checkpoints,
)
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_indexing import run_indexing_entrypoint
from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingMode
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
@@ -81,136 +70,8 @@ from shared_configs.configs import SENTRY_DSN
logger = setup_logger()
class IndexingWatchdogTerminalStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SPAWN_FAILED = "spawn_failed" # connector spawn failed
SPAWN_NOT_ALIVE = (
"spawn_not_alive" # spawn succeeded but process did not come alive
)
BLOCKED_BY_DELETION = "blocked_by_deletion"
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
FENCE_NOT_FOUND = "fence_not_found" # fence does not exist
FENCE_READINESS_TIMEOUT = (
"fence_readiness_timeout" # fence exists but wasn't ready within the timeout
)
FENCE_MISMATCH = "fence_mismatch" # task and fence metadata mismatch
TASK_ALREADY_RUNNING = "task_already_running" # task appears to be running already
INDEX_ATTEMPT_MISMATCH = (
"index_attempt_mismatch" # expected index attempt metadata not found in db
)
CONNECTOR_VALIDATION_ERROR = (
"connector_validation_error" # the connector validation failed
)
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
# the watchdog received a termination signal
TERMINATED_BY_SIGNAL = "terminated_by_signal"
# the watchdog terminated the task due to no activity
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
# NOTE: this may actually be the same as SIGKILL, but parsed differently by python
# consolidate once we know more
OUT_OF_MEMORY = "out_of_memory"
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
@property
def code(self) -> int:
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
}
return _ENUM_TO_CODE[self]
@classmethod
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
}
if code in _CODE_TO_ENUM:
return _CODE_TO_ENUM[code]
return IndexingWatchdogTerminalStatus.UNDEFINED
class SimpleJobResult:
"""The data we want to have when the watchdog finishes"""
def __init__(self) -> None:
self.status = IndexingWatchdogTerminalStatus.UNDEFINED
self.connector_source = None
self.exit_code = None
self.exception_str = None
status: IndexingWatchdogTerminalStatus
connector_source: str | None
exit_code: int | None
exception_str: str | None
class ConnectorIndexingContext(BaseModel):
tenant_id: str
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
class ConnectorIndexingLogBuilder:
def __init__(self, ctx: ConnectorIndexingContext):
self.ctx = ctx
def build(self, msg: str, **kwargs: Any) -> str:
msg_final = (
f"{msg}: "
f"tenant_id={self.ctx.tenant_id} "
f"attempt={self.ctx.index_attempt_id} "
f"cc_pair={self.ctx.cc_pair_id} "
f"search_settings={self.ctx.search_settings_id}"
)
# Append extra keyword arguments in logfmt style
if kwargs:
extra_logfmt = " ".join(f"{key}={value}" for key, value in kwargs.items())
msg_final = f"{msg_final} {extra_logfmt}"
return msg_final
def monitor_ccpair_indexing_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
@@ -358,16 +219,15 @@ def monitor_ccpair_indexing_taskset(
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
tasks_created = 0
locked = False
redis_client = get_redis_client()
redis_client_replica = get_redis_replica_client()
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
@@ -405,7 +265,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# 1/3: KICKOFF
# check for search settings swap
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
@@ -426,7 +286,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# gather cc_pair_ids
lock_beat.reacquire()
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
@@ -436,7 +296,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
search_settings_list = get_active_search_settings_list(db_session)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
@@ -514,7 +374,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
db_session, redis_client
)
@@ -566,7 +426,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(
tenant_id, key_bytes, redis_client_replica, db_session
)
@@ -597,8 +457,8 @@ def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
tenant_id: str,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -636,6 +496,7 @@ def connector_indexing_task(
f"search_settings={search_settings_id}"
)
attempt_found = False
n_final_progress: int | None = None
# 20 is the documented default for httpx max_keepalive_connections
@@ -649,24 +510,22 @@ def connector_indexing_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
if redis_connector.delete.fenced:
raise SimpleJobException(
raise RuntimeError(
f"Indexing will not start because connector deletion is in progress: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.delete.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code,
f"fence={redis_connector.delete.fence_key}"
)
if redis_connector.stop.fenced:
raise SimpleJobException(
raise RuntimeError(
f"Indexing will not start because a connector stop signal was detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"fence={redis_connector.stop.fence_key}",
code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code,
f"fence={redis_connector.stop.fence_key}"
)
# this wait is needed to avoid a race condition where
@@ -675,24 +534,19 @@ def connector_indexing_task(
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise SimpleJobException(
raise ValueError(
f"connector_indexing_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.permissions.fence_key}",
code=IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT.code,
f"fence={redis_connector.permissions.fence_key}"
)
if not redis_connector_index.fenced: # The fence must exist
raise SimpleJobException(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}",
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
raise ValueError(
f"connector_indexing_task - fence not found: fence={redis_connector_index.fence_key}"
)
payload = redis_connector_index.payload # The payload must exist
if not payload:
raise SimpleJobException(
"connector_indexing_task: payload invalid or not found",
code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code,
)
raise ValueError("connector_indexing_task: payload invalid or not found")
if payload.index_attempt_id is None or payload.celery_task_id is None:
logger.info(
@@ -702,11 +556,10 @@ def connector_indexing_task(
continue
if payload.index_attempt_id != index_attempt_id:
raise SimpleJobException(
raise ValueError(
f"connector_indexing_task - id mismatch. Task may be left over from previous run.: "
f"task_index_attempt={index_attempt_id} "
f"payload_index_attempt={payload.index_attempt_id}",
code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code,
f"payload_index_attempt={payload.index_attempt_id}"
)
logger.info(
@@ -730,26 +583,19 @@ def connector_indexing_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise SimpleJobException(
f"Indexing task already running, exiting...: "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING.code,
)
return None
payload.started = datetime.now(timezone.utc)
redis_connector_index.set_fence(payload)
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
f"Index attempt not found: index_attempt={index_attempt_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
raise ValueError(
f"Index attempt not found: index_attempt={index_attempt_id}"
)
attempt_found = True
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
@@ -757,30 +603,25 @@ def connector_indexing_task(
)
if not cc_pair:
raise SimpleJobException(
f"cc_pair not found: cc_pair={cc_pair_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
)
raise ValueError(f"cc_pair not found: cc_pair={cc_pair_id}")
if not cc_pair.connector:
raise SimpleJobException(
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
raise ValueError(
f"Connector not found: cc_pair={cc_pair_id} connector={cc_pair.connector_id}"
)
if not cc_pair.credential:
raise SimpleJobException(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}",
code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code,
raise ValueError(
f"Credential not found: cc_pair={cc_pair_id} credential={cc_pair.credential_id}"
)
# define a callback class
callback = IndexingCallback(
os.getppid(),
redis_connector,
redis_connector_index,
lock,
r,
redis_connector_index,
)
logger.info(
@@ -802,15 +643,6 @@ def connector_indexing_task(
# get back the total number of indexed docs and return it
n_final_progress = redis_connector_index.get_progress()
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
except ConnectorValidationError:
raise SimpleJobException(
f"Indexing task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
@@ -818,8 +650,22 @@ def connector_indexing_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise e
if attempt_found:
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_failed(
index_attempt_id, db_session, failure_reason=str(e)
)
except Exception:
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise e
finally:
if lock.owned():
lock.release()
@@ -832,49 +678,41 @@ def connector_indexing_task(
return n_final_progress
def process_job_result(
job: SimpleJob,
connector_source: str | None,
redis_connector_index: RedisConnectorIndex,
log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
result = SimpleJobResult()
result.connector_source = connector_source
def connector_indexing_task_wrapper(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
) -> int | None:
"""Just wraps connector_indexing_task so we can log any exceptions before
re-raising it."""
result: int | None = None
if job.process:
result.exit_code = job.process.exitcode
if job.status != "error":
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
return result
ignore_exitcode = False
# In EKS, there is an edge case where successful tasks return exit
# code 1 in the cloud due to the set_spawn_method not sticking.
# We've since worked around this, but the following is a safe way to
# work around this issue. Basically, we ignore the job error state
# if the completion signal is OK.
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if ignore_exitcode:
result.status = IndexingWatchdogTerminalStatus.SUCCEEDED
task_logger.warning(
log_builder.build(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...",
exit_code=str(result.exit_code),
)
try:
result = connector_indexing_task(
index_attempt_id,
cc_pair_id,
search_settings_id,
tenant_id,
is_ee,
)
except Exception:
logger.exception(
f"connector_indexing_task exceptioned: "
f"tenant={tenant_id} "
f"index_attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
else:
if result.exit_code is not None:
result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code)
result.exception_str = job.exception()
# There is a cloud related bug outside of our code
# where spawned tasks return with an exit code of 1.
# Unfortunately, exceptions also return with an exit code of 1,
# so just raising an exception isn't informative
# Exiting with 255 makes it possible to distinguish between normal exits
# and exceptions.
sys.exit(255)
return result
@@ -890,336 +728,165 @@ def connector_indexing_proxy_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
To work around this, we use pool=threads and proxy our work to a spawned task.
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
NOTE: we try/except all db access in this function because as a watchdog, this function
needs to be extremely stable.
"""
start = time.monotonic()
result = SimpleJobResult()
ctx = ConnectorIndexingContext(
tenant_id=tenant_id,
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
index_attempt_id=index_attempt_id,
)
log_builder = ConnectorIndexingLogBuilder(ctx)
"""celery tasks are forked, but forking is unstable. This proxies work to a spawned task."""
task_logger.info(
log_builder.build(
"Indexing watchdog - starting",
mp_start_method=str(multiprocessing.get_start_method()),
)
f"Indexing watchdog - starting: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"mp_start_method={multiprocessing.get_start_method()}"
)
if not self.request.id:
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
task_logger.info(f"submitting connector_indexing_task with tenant_id={tenant_id}")
job = client.submit(
connector_indexing_task,
connector_indexing_task_wrapper,
index_attempt_id,
cc_pair_id,
search_settings_id,
global_version.is_ee_version(),
tenant_id,
global_version.is_ee_version(),
pure=False,
)
if not job or not job.process:
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
if not job:
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
f"Indexing watchdog - spawn failed: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
return
# Ensure the process has moved out of the starting state
num_waits = 0
while True:
if num_waits > 15:
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
job.release()
return
if job.process.is_alive() or job.process.exitcode is not None:
break
sleep(1)
num_waits += 1
task_logger.info(
log_builder.build(
"Indexing watchdog - spawn succeeded",
pid=str(job.process.pid),
)
f"Indexing watchdog - spawn succeeded: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
raise RuntimeError("Index attempt not found")
while True:
sleep(5)
result.connector_source = (
index_attempt.connector_credential_pair.connector.source.value
)
# renew watchdog signal (this has a shorter timeout than set_active)
redis_connector_index.set_watchdog(True)
redis_connector_index.set_active() # renew active signal
redis_connector_index.set_connector_active() # prime the connective active signal
# renew active signal
redis_connector_index.set_active()
while True:
sleep(5)
# renew watchdog signal (this has a shorter timeout than set_active)
redis_connector_index.set_watchdog(True)
# renew active signal
redis_connector_index.set_active()
# if the job is done, clean up and break
if job.done():
try:
result = process_job_result(
job, result.connector_source, redis_connector_index, log_builder
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - spawned task exceptioned"
)
)
finally:
job.release()
break
# if a termination signal is detected, break (exit point will clean up)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
log_builder.build("Indexing watchdog - termination signal detected")
)
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
break
# if activity timeout is detected, break (exit point will clean up)
if not redis_connector_index.connector_active():
task_logger.warning(
log_builder.build(
"Indexing watchdog - activity timeout exceeded",
timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
)
result.status = (
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
)
break
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
# if the job is done, clean up and break
if job.done():
exit_code: int | None
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
if job.status == "error":
ignore_exitcode = False
exit_code = None
if job.process:
exit_code = job.process.exitcode
# seeing odd behavior where spawned tasks usually return exit code 1 in the cloud,
# even though logging clearly indicates successful completion
# to work around this, we ignore the job error state if the completion signal is OK
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
ignore_exitcode = True
if not ignore_exitcode:
raise RuntimeError("Spawned task exceptioned.")
task_logger.warning(
"Indexing watchdog - spawned task has non-zero exit code "
"but completion signal is OK. Continuing...: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code}"
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception as e:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
if isinstance(e, ConnectorValidationError):
# No need to expose full stack trace for validation errors
result.exception_str = str(e)
else:
result.exception_str = traceback.format_exc()
# handle exit and reporting
elapsed = time.monotonic() - start
if result.exception_str is not None:
# print with exception
try:
with get_session_with_current_tenant() as db_session:
failure_reason = (
f"Spawned task exceptioned: exit_code={result.exit_code}"
)
mark_attempt_failed(
ctx.index_attempt_id,
db_session,
failure_reason=failure_reason,
full_exception_trace=result.exception_str,
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
normalized_exception_str = "None"
if result.exception_str:
normalized_exception_str = result.exception_str.replace(
"\n", "\\n"
).replace('"', '\\"')
task_logger.warning(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=result.status.value,
exit_code=str(result.exit_code),
exception=f'"{normalized_exception_str}"',
elapsed=f"{elapsed:.2f}s",
)
)
redis_connector_index.set_watchdog(False)
raise RuntimeError(f"Exception encountered: traceback={result.exception_str}")
# print without exception
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
except Exception:
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
)
)
job.cancel()
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
task_logger.error(
"Indexing watchdog - spawned task exceptioned: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id} "
f"exit_code={exit_code} "
f"error={job.exception()}"
)
except Exception:
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
else:
pass
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
source=result.connector_source,
status=str(result.status.value),
exit_code=str(result.exit_code),
elapsed=f"{elapsed:.2f}s",
)
)
raise
finally:
job.release()
break
# if a termination signal is detected, clean up and break
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
"Indexing watchdog - termination signal detected: "
f"attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
try:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
"Indexing watchdog - transient exception marking index attempt as canceled: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
job.cancel()
break
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
if not index_attempt:
continue
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
logger.exception(
"Indexing watchdog - transient exception looking up index attempt: "
f"attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
continue
redis_connector_index.set_watchdog(False)
return
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
)
def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
"""Clean up old checkpoints that are older than 7 days."""
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
lock: RedisLock = redis_client.lock(
OnyxRedisLocks.CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
task_logger.info(
f"Indexing watchdog - finished: attempt={index_attempt_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
# these tasks should never overlap
if not lock.acquire(blocking=False):
return None
try:
locked = True
with get_session_with_current_tenant() as db_session:
old_attempts = get_index_attempts_with_old_checkpoints(db_session)
for attempt in old_attempts:
task_logger.info(
f"Cleaning up checkpoint for index attempt {attempt.id}"
)
cleanup_checkpoint_task.apply_async(
kwargs={
"index_attempt_id": attempt.id,
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.CHECKPOINT_CLEANUP,
)
except Exception:
task_logger.exception("Unexpected exception during checkpoint cleanup")
return None
finally:
if locked:
if lock.owned():
lock.release()
else:
task_logger.error(
"check_for_checkpoint_cleanup - Lock not owned on completion: "
f"tenant={tenant_id}"
)
@shared_task(
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
bind=True,
)
def cleanup_checkpoint_task(
self: Task, *, index_attempt_id: int, tenant_id: str | None
) -> None:
"""Clean up a checkpoint for a given index attempt"""
with get_session_with_current_tenant() as db_session:
cleanup_checkpoint(db_session, index_attempt_id)
return

View File

@@ -23,7 +23,7 @@ from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
@@ -93,25 +93,27 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
return unfenced_attempts
class IndexingCallbackBase(IndexingHeartbeatInterface):
class IndexingCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_connector_index: RedisConnectorIndex,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_connector: RedisConnector = redis_connector
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
self.redis_lock: RedisLock = redis_lock
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = f"{self.__class__.__name__}.__init__"
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
@@ -125,8 +127,8 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
def progress(self, tag: str, amount: int) -> None:
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
# with daemon=True. It seems likely some indexing tasks will need to spawn other processes
# eventually, which daemon=True prevents, so leave this code in until we're ready to test it.
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
# so leave this code in until we're ready to test it.
# if self.parent_pid:
# # check if the parent pid is alive so we aren't running as a zombie
@@ -141,6 +143,8 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
# self.last_parent_check = now
try:
self.redis_connector.prune.set_active()
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
@@ -152,7 +156,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_tag = tag
except LockError:
logger.exception(
f"{self.__class__.__name__} - lock.reacquire exceptioned: "
f"IndexingCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
@@ -163,31 +167,13 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
redis_lock_dump(self.redis_lock, self.redis_client)
raise
class IndexingCallback(IndexingCallbackBase):
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
redis_connector_index: RedisConnectorIndex,
):
super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
def progress(self, tag: str, amount: int) -> None:
self.redis_connector_index.set_active()
self.redis_connector_index.set_connector_active()
super().progress(tag, amount)
self.redis_client.incrby(
self.redis_connector_index.generator_progress_key, amount
)
def validate_indexing_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
@@ -254,8 +240,7 @@ def validate_indexing_fence(
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
f"validate_indexing_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
f"validate_indexing_fence - Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector_index.reset()
return
@@ -311,7 +296,7 @@ def validate_indexing_fence(
def validate_indexing_fences(
tenant_id: str,
tenant_id: str | None,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
@@ -332,7 +317,7 @@ def validate_indexing_fences(
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
continue
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
@@ -442,7 +427,7 @@ def try_creating_indexing_task(
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.

View File

@@ -8,7 +8,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import LLMProvider
@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
trail=False,
bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
if not LLM_MODEL_UPDATE_API_URL:
raise ValueError("LLM model update API URL not configured")
@@ -75,7 +75,7 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
return None
# Then update the database with the fetched models
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
# Get the default LLM provider
default_provider = (
db_session.query(LLMProvider)

View File

@@ -26,8 +26,7 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -43,6 +42,7 @@ from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
@@ -91,7 +91,7 @@ class Metric(BaseModel):
}
task_logger.info(json.dumps(data))
def emit(self, tenant_id: str) -> None:
def emit(self, tenant_id: str | None) -> None:
# Convert value to appropriate type based on the input value
bool_value = None
float_value = None
@@ -656,7 +656,7 @@ def build_job_id(
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
"""Collect and emit metrics about background processes.
This task runs periodically to gather metrics about:
- Queue lengths for different Celery queues
@@ -668,7 +668,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
task_logger.info("Starting background monitoring")
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
lock_monitoring: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
@@ -683,7 +683,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client()
redis_std = get_redis_client(tenant_id=tenant_id)
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
@@ -693,7 +693,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
]
# Collect and log each metric
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
@@ -771,11 +771,12 @@ def cloud_check_alembic() -> bool | None:
if tenant_id is None:
continue
with get_session_with_shared_schema() as session:
with get_session_with_tenant(tenant_id=None) as session:
try:
result = session.execute(
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
)
result_scalar: str | None = result.scalar_one_or_none()
if result_scalar is None:
raise ValueError("Alembic version should not be None.")
@@ -864,7 +865,7 @@ def cloud_monitor_celery_queues(
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
return monitor_celery_queues_helper(self)

View File

@@ -15,7 +15,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PostgresAdvisoryLocks
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
@shared_task(
@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
@@ -36,7 +36,7 @@ def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),

View File

@@ -21,7 +21,7 @@ from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -41,7 +41,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_documents_for_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -55,7 +55,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@@ -63,12 +62,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class PruneCallback(IndexingCallbackBase):
def progress(self, tag: str, amount: int) -> None:
self.redis_connector.prune.set_active()
super().progress(tag, amount)
"""Jobs / utils for kicking off pruning tasks."""
@@ -114,9 +107,9 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -134,14 +127,14 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
# but pruning only kicks off once per hour
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -189,20 +182,18 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(f"Unexpected pruning check exception: {error_msg}")
except Exception:
task_logger.exception("Unexpected exception during pruning check")
finally:
if lock_beat.owned():
lock_beat.release()
task_logger.info(f"check_for_pruning finished: tenant={tenant_id}")
return True
@@ -211,7 +202,7 @@ def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
@@ -304,19 +295,13 @@ def try_creating_prune_generator_task(
redis_connector.prune.set_fence(payload)
payload_id = payload.id
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_prune_generator_task exception: cc_pair={cc_pair.id} {error_msg}"
)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"try_creating_prune_generator_task finished: cc_pair={cc_pair.id} payload_id={payload_id}"
)
return payload_id
@@ -333,7 +318,7 @@ def connector_pruning_generator_task(
cc_pair_id: int,
connector_id: int,
credential_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -352,7 +337,7 @@ def connector_pruning_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -410,7 +395,7 @@ def connector_pruning_generator_task(
return None
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
@@ -440,7 +425,6 @@ def connector_pruning_generator_task(
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source}"
)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@@ -450,11 +434,12 @@ def connector_pruning_generator_task(
)
search_settings = get_current_search_settings(db_session)
redis_connector.new_index(search_settings.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
callback = PruneCallback(
callback = IndexingCallback(
0,
redis_connector,
redis_connector_index,
lock,
r,
)
@@ -521,7 +506,7 @@ def connector_pruning_generator_task(
def monitor_ccpair_pruning_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -567,7 +552,7 @@ def monitor_ccpair_pruning_taskset(
def validate_pruning_fences(
tenant_id: str,
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
@@ -615,7 +600,7 @@ def validate_pruning_fences(
def validate_pruning_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
queued_tasks: set[str],

View File

@@ -32,7 +32,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
) -> int:
return self.index.delete_single(
@@ -50,7 +50,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:

View File

@@ -1,5 +1,4 @@
import time
from enum import Enum
from http import HTTPStatus
import httpx
@@ -28,7 +27,7 @@ from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
@@ -46,24 +45,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
class OnyxCeleryTaskCompletionStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SKIPPED = "skipped"
SOFT_TIME_LIMIT = "soft_time_limit"
NON_RETRYABLE_EXCEPTION = "non_retryable_exception"
RETRYABLE_EXCEPTION = "retryable_exception"
@shared_task(
name=OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
@@ -76,7 +57,7 @@ def document_by_cc_pair_cleanup_task(
document_id: str,
connector_id: int,
credential_id: int,
tenant_id: str,
tenant_id: str | None,
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connection deletion and connector pruning parent tasks."""
@@ -97,10 +78,8 @@ def document_by_cc_pair_cleanup_task(
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
chunks_affected = 0
@@ -126,14 +105,10 @@ def document_by_cc_pair_cleanup_task(
tenant_id=tenant_id,
chunk_count=chunk_count,
)
delete_documents_complete__no_commit(
db_session=db_session,
document_ids=[document_id],
)
db_session.commit()
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
elif count > 1:
action = "update"
@@ -177,11 +152,10 @@ def document_by_cc_pair_cleanup_task(
)
mark_document_as_synced(document_id, db_session)
db_session.commit()
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
else:
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
pass
db_session.commit()
elapsed = time.monotonic() - start
task_logger.info(
@@ -193,79 +167,57 @@ def document_by_cc_pair_cleanup_task(
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
return False
except Exception as ex:
e: Exception | None = None
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(
f"document_by_cc_pair_cleanup_task exceptioned: doc={document_id}"
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
with get_session_with_current_tenant() as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
return False
task_logger.exception(f"Unexpected exception: doc={document_id}")
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"document_by_cc_pair_cleanup_task completed: status={completion_status.value} doc={document_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
self.retry(exc=e, countdown=countdown)
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
return False
task_logger.info(f"document_by_cc_pair_cleanup_task finished: doc={document_id}")
return True
@@ -297,8 +249,7 @@ def cloud_beat_task_generator(
return None
last_lock_time = time.monotonic()
tenant_ids: list[str] = []
num_processed_tenants = 0
tenant_ids: list[str] | list[None] = []
try:
tenant_ids = get_all_tenant_ids()
@@ -326,8 +277,6 @@ def cloud_beat_task_generator(
expires=expires,
ignore_result=True,
)
num_processed_tenants += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -347,7 +296,6 @@ def cloud_beat_task_generator(
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)

View File

@@ -19,7 +19,6 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -35,7 +34,7 @@ from onyx.db.document_set import fetch_document_sets
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import mark_document_set_as_synced
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import DocumentSet
@@ -76,7 +75,7 @@ logger = setup_logger()
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
@@ -85,8 +84,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
time_start = time.monotonic()
r = get_redis_client()
r_replica = get_redis_replica_client()
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -99,7 +98,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
try:
# 1/3: KICKOFF
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
)
@@ -107,7 +106,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
# region document set scan
lock_beat.reacquire()
document_set_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
@@ -118,7 +117,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
for document_set_id in document_set_ids:
lock_beat.reacquire()
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_document_set_sync_tasks(
self.app, document_set_id, db_session, r, lock_beat, tenant_id
)
@@ -137,7 +136,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
pass
else:
usergroup_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
@@ -147,7 +146,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
for usergroup_id in usergroup_ids:
lock_beat.reacquire()
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_user_group_sync_tasks(
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
)
@@ -168,7 +167,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
monitor_usergroup_taskset = (
@@ -178,7 +177,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
noop_fallback,
)
)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
@@ -208,7 +207,7 @@ def try_generate_stale_document_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
@@ -284,7 +283,7 @@ def try_generate_document_set_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -361,7 +360,7 @@ def try_generate_user_group_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -448,7 +447,7 @@ def monitor_connector_taskset(r: Redis) -> None:
def monitor_document_set_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -523,13 +522,13 @@ def monitor_document_set_taskset(
time_limit=LIGHT_TIME_LIMIT,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
@@ -541,103 +540,75 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
doc = get_document(document_id, db_session)
if not doc:
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=no_operation "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
else:
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
return False
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
except Exception as ex:
e: Exception | None = None
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(
f"vespa_metadata_sync_task exceptioned: doc={document_id}"
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"vespa_metadata_sync_task completed: status={completion_status.value} doc={document_id}"
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
return False
task_logger.exception(
f"Unexpected exception during vespa metadata sync: doc={document_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True

View File

@@ -1,7 +1,5 @@
from sqlalchemy.exc import IntegrityError
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
def emit_background_error(
@@ -11,10 +9,5 @@ def emit_background_error(
"""Currently just saves a row in the background_errors table.
In the future, could create notifications based on the severity."""
with get_session_with_current_tenant() as db_session:
try:
create_background_error(db_session, message, cc_pair_id)
except IntegrityError as e:
# Log an error if the cc_pair_id was deleted or any other exception occurs
error_message = f"Failed to create background error: {str(e)}. Original message: {message}"
create_background_error(db_session, error_message, None)
with get_session_with_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)

View File

@@ -0,0 +1,80 @@
"""Experimental functionality related to splitting up indexing
into a series of checkpoints to better handle intermittent failures
/ jobs being killed by cloud providers."""
import datetime
from onyx.configs.app_configs import EXPERIMENTAL_CHECKPOINTING_ENABLED
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc
def _2010_dt() -> datetime.datetime:
return datetime.datetime(year=2010, month=1, day=1, tzinfo=datetime.timezone.utc)
def _2020_dt() -> datetime.datetime:
return datetime.datetime(year=2020, month=1, day=1, tzinfo=datetime.timezone.utc)
def _default_end_time(
last_successful_run: datetime.datetime | None,
) -> datetime.datetime:
"""If year is before 2010, go to the beginning of 2010.
If year is 2010-2020, go in 5 year increments.
If year > 2020, then go in 180 day increments.
For connectors that don't support a `filter_by` and instead rely on `sort_by`
for polling, then this will cause a massive duplication of fetches. For these
connectors, you may want to override this function to return a more reasonable
plan (e.g. extending the 2020+ windows to 6 months, 1 year, or higher)."""
last_successful_run = (
datetime_to_utc(last_successful_run) if last_successful_run else None
)
if last_successful_run is None or last_successful_run < _2010_dt():
return _2010_dt()
if last_successful_run < _2020_dt():
return min(last_successful_run + datetime.timedelta(days=365 * 5), _2020_dt())
return last_successful_run + datetime.timedelta(days=180)
def find_end_time_for_indexing_attempt(
last_successful_run: datetime.datetime | None,
# source_type can be used to override the default for certain connectors, currently unused
source_type: DocumentSource,
) -> datetime.datetime | None:
"""Is the current time unless the connector is run over a large period, in which case it is
split up into large time segments that become smaller as it approaches the present
"""
# NOTE: source_type can be used to override the default for certain connectors
end_of_window = _default_end_time(last_successful_run)
now = datetime.datetime.now(tz=datetime.timezone.utc)
if end_of_window < now:
return end_of_window
# None signals that we should index up to current time
return None
def get_time_windows_for_index_attempt(
last_successful_run: datetime.datetime, source_type: DocumentSource
) -> list[tuple[datetime.datetime, datetime.datetime]]:
if not EXPERIMENTAL_CHECKPOINTING_ENABLED:
return [(last_successful_run, datetime.datetime.now(tz=datetime.timezone.utc))]
time_windows: list[tuple[datetime.datetime, datetime.datetime]] = []
start_of_window: datetime.datetime | None = last_successful_run
while start_of_window:
end_of_window = find_end_time_for_indexing_attempt(
last_successful_run=start_of_window, source_type=source_type
)
time_windows.append(
(
start_of_window,
end_of_window or datetime.datetime.now(tz=datetime.timezone.utc),
)
)
start_of_window = end_of_window
return time_windows

View File

@@ -1,200 +0,0 @@
from datetime import datetime
from datetime import timedelta
from io import BytesIO
from sqlalchemy import and_
from sqlalchemy.orm import Session
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import ConnectorCheckpoint
from onyx.db.engine import get_db_current_time
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from onyx.utils.object_size_check import deep_getsizeof
logger = setup_logger()
_NUM_RECENT_ATTEMPTS_TO_CONSIDER = 20
_NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT = 100
def _build_checkpoint_pointer(index_attempt_id: int) -> str:
return f"checkpoint_{index_attempt_id}.json"
def save_checkpoint(
db_session: Session, index_attempt_id: int, checkpoint: ConnectorCheckpoint
) -> str:
"""Save a checkpoint for a given index attempt to the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
file_store.save_file(
file_name=checkpoint_pointer,
content=BytesIO(checkpoint.model_dump_json().encode()),
display_name=checkpoint_pointer,
file_origin=FileOrigin.INDEXING_CHECKPOINT,
file_type="application/json",
)
index_attempt = get_index_attempt(db_session, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
index_attempt.checkpoint_pointer = checkpoint_pointer
db_session.add(index_attempt)
db_session.commit()
return checkpoint_pointer
def load_checkpoint(
db_session: Session, index_attempt_id: int
) -> ConnectorCheckpoint | None:
"""Load a checkpoint for a given index attempt from the file store"""
checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id)
file_store = get_default_file_store(db_session)
try:
checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb")
checkpoint_data = checkpoint_io.read().decode("utf-8")
return ConnectorCheckpoint.model_validate_json(checkpoint_data)
except RuntimeError:
return None
def get_latest_valid_checkpoint(
db_session: Session,
cc_pair_id: int,
search_settings_id: int,
window_start: datetime,
window_end: datetime,
) -> ConnectorCheckpoint:
"""Get the latest valid checkpoint for a given connector credential pair"""
checkpoint_candidates = get_recent_completed_attempts_for_cc_pair(
cc_pair_id=cc_pair_id,
search_settings_id=search_settings_id,
db_session=db_session,
limit=_NUM_RECENT_ATTEMPTS_TO_CONSIDER,
)
checkpoint_candidates = [
candidate
for candidate in checkpoint_candidates
if (
candidate.poll_range_start == window_start
and candidate.poll_range_end == window_end
and candidate.status == IndexingStatus.FAILED
and candidate.checkpoint_pointer is not None
# we want to make sure that the checkpoint is actually useful
# if it's only gone through a few docs, it's probably not worth
# using. This also avoids weird cases where a connector is basically
# non-functional but still "makes progress" by slowly moving the
# checkpoint forward run after run
and candidate.total_docs_indexed
and candidate.total_docs_indexed > _NUM_DOCS_INDEXED_TO_BE_VALID_CHECKPOINT
)
]
# don't keep using checkpoints if we've had a bunch of failed attempts in a row
# for now, capped at 10
if len(checkpoint_candidates) == _NUM_RECENT_ATTEMPTS_TO_CONSIDER:
logger.warning(
f"{_NUM_RECENT_ATTEMPTS_TO_CONSIDER} consecutive failed attempts found "
f"for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start "
"from scratch."
)
return ConnectorCheckpoint.build_dummy_checkpoint()
# assumes latest checkpoint is the furthest along. This only isn't true
# if something else has gone wrong.
latest_valid_checkpoint_candidate = (
checkpoint_candidates[0] if checkpoint_candidates else None
)
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
if latest_valid_checkpoint_candidate:
try:
previous_checkpoint = load_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
)
except Exception:
logger.exception(
f"Failed to load checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}."
)
previous_checkpoint = None
if previous_checkpoint is not None:
logger.info(
f"Using checkpoint from previous failed attempt with ID "
f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: "
f"{previous_checkpoint}"
)
save_checkpoint(
db_session=db_session,
index_attempt_id=latest_valid_checkpoint_candidate.id,
checkpoint=previous_checkpoint,
)
checkpoint = previous_checkpoint
return checkpoint
def get_index_attempts_with_old_checkpoints(
db_session: Session, days_to_keep: int = 7
) -> list[IndexAttempt]:
"""Get all index attempts with checkpoints older than the specified number of days.
Args:
db_session: The database session
days_to_keep: Number of days to keep checkpoints for (default: 7)
Returns:
Number of checkpoints deleted
"""
cutoff_date = get_db_current_time(db_session) - timedelta(days=days_to_keep)
# Find all index attempts with checkpoints older than cutoff_date
old_attempts = (
db_session.query(IndexAttempt)
.filter(
and_(
IndexAttempt.checkpoint_pointer.isnot(None),
IndexAttempt.time_created < cutoff_date,
)
)
.all()
)
return old_attempts
def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None:
"""Clean up a checkpoint for a given index attempt"""
index_attempt = get_index_attempt(db_session, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if not index_attempt.checkpoint_pointer:
return None
file_store = get_default_file_store(db_session)
file_store.delete_file(index_attempt.checkpoint_pointer)
index_attempt.checkpoint_pointer = None
db_session.add(index_attempt)
db_session.commit()
return None
def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None:
"""Check if the checkpoint content size exceeds the limit (200MB)"""
content_size = deep_getsizeof(checkpoint.checkpoint_content)
if content_size > 200_000_000: # 200MB in bytes
raise ValueError(
f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit"
)

View File

@@ -5,8 +5,6 @@ not follow the expected behavior, etc.
NOTE: cannot use Celery directly due to
https://github.com/celery/celery/issues/7007#issuecomment-1740139367"""
import multiprocessing as mp
import sys
import traceback
from collections.abc import Callable
from dataclasses import dataclass
from multiprocessing.context import SpawnProcess
@@ -16,23 +14,10 @@ from typing import Optional
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.setup import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from onyx.utils.logger import setup_logger
logger = setup_logger()
class SimpleJobException(Exception):
"""lets us raise an exception that will return a specific error code"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
code: int | None = kwargs.pop("code", None)
self.code = code
super().__init__(*args, **kwargs)
JobStatusType = (
Literal["error"]
| Literal["finished"]
@@ -43,10 +28,7 @@ JobStatusType = (
def _initializer(
func: Callable,
queue: mp.Queue,
args: list | tuple,
kwargs: dict[str, Any] | None = None,
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> Any:
"""Initialize the child process with a fresh SQLAlchemy Engine.
@@ -57,15 +39,6 @@ def _initializer(
kwargs = {}
logger.info("Initializing spawned worker child process.")
# 1. Get tenant_id from args or fallback to default
tenant_id = POSTGRES_DEFAULT_SCHEMA
for arg in reversed(args):
if isinstance(arg, str) and arg.startswith(TENANT_ID_PREFIX):
tenant_id = arg
break
# 2. Set the tenant context before running anything
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Reset the engine in the child process
SqlEngine.reset_engine()
@@ -79,31 +52,13 @@ def _initializer(
)
# Proceed with executing the target function
try:
return func(*args, **kwargs)
except SimpleJobException as e:
logger.exception("SimpleJob raised a SimpleJobException")
error_msg = traceback.format_exc()
queue.put(error_msg) # Send the exception to the parent process
sys.exit(e.code) # use the given exit code
except Exception:
logger.exception("SimpleJob raised an exception")
error_msg = traceback.format_exc()
queue.put(error_msg) # Send the exception to the parent process
sys.exit(255) # use 255 to indicate a generic exception
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
return func(*args, **kwargs)
def _run_in_process(
func: Callable,
queue: mp.Queue,
args: list | tuple,
kwargs: dict[str, Any] | None = None,
func: Callable, args: list | tuple, kwargs: dict[str, Any] | None = None
) -> None:
_initializer(func, queue, args, kwargs)
_initializer(func, args, kwargs)
@dataclass
@@ -112,8 +67,6 @@ class SimpleJob:
id: int
process: Optional["SpawnProcess"] = None
queue: Optional[mp.Queue] = None
_exception: Optional[str] = None
def cancel(self) -> bool:
return self.release()
@@ -147,15 +100,9 @@ class SimpleJob:
def exception(self) -> str:
"""Needed to match the Dask API, but not implemented since we don't currently
have a way to get back the exception information from the child process."""
"""Retrieve exception from the multiprocessing queue if available."""
if self._exception is None and self.queue and not self.queue.empty():
self._exception = self.queue.get() # Get exception from queue
if self._exception:
return self._exception
return f"Job with ID '{self.id}' did not report an exception."
return (
f"Job with ID '{self.id}' was killed or encountered an unhandled exception."
)
class SimpleJobClient:
@@ -190,11 +137,8 @@ class SimpleJobClient:
# this approach allows us to always "spawn" a new process regardless of
# get_start_method's current setting
ctx = mp.get_context("spawn")
queue = ctx.Queue()
process = ctx.Process(
target=_run_in_process, args=(func, queue, args), daemon=True
)
job = SimpleJob(id=job_id, process=process, queue=queue)
process = ctx.Process(target=_run_in_process, args=(func, args), daemon=True)
job = SimpleJob(id=job_id, process=process)
process.start()
self.jobs[job_id] = job

View File

@@ -1,87 +0,0 @@
import tracemalloc
from onyx.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class MemoryTracer:
def __init__(self, interval: int = 0, num_print_entries: int = 5):
self.interval = interval
self.num_print_entries = num_print_entries
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
self.counter = 0
def start(self) -> None:
"""Start the memory tracer if interval is greater than 0."""
if self.interval > 0:
logger.debug(f"Memory tracer starting: interval={self.interval}")
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
self._take_snapshot()
def stop(self) -> None:
"""Stop the memory tracer if it's running."""
if self.interval > 0:
self.log_final_diff()
tracemalloc.stop()
logger.debug("Memory tracer stopped.")
def _take_snapshot(self) -> None:
"""Take a snapshot and update internal snapshot states."""
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__),
tracemalloc.Filter(False, "<frozen importlib._bootstrap>"),
tracemalloc.Filter(False, "<frozen importlib._bootstrap_external>"),
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def _log_diff(
self, current: tracemalloc.Snapshot, previous: tracemalloc.Snapshot
) -> None:
"""Log the memory difference between two snapshots."""
stats = current.compare_to(previous, "traceback")
for s in stats[: self.num_print_entries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def increment_and_maybe_trace(self) -> None:
"""Increment counter and perform trace if interval is hit."""
if self.interval <= 0:
return
self.counter += 1
if self.counter % self.interval == 0:
logger.debug(
f"Running trace comparison for batch {self.counter}. interval={self.interval}"
)
self._take_snapshot()
if self.snapshot and self.snapshot_prev:
self._log_diff(self.snapshot, self.snapshot_prev)
def log_final_diff(self) -> None:
"""Log the final memory diff between start and end of indexing."""
if self.interval <= 0:
return
logger.debug(
f"Running trace comparison between start and end of indexing. {self.counter} batches processed."
)
self._take_snapshot()
if self.snapshot and self.snapshot_first:
self._log_diff(self.snapshot, self.snapshot_first)

View File

@@ -1,40 +0,0 @@
from datetime import datetime
from pydantic import BaseModel
from onyx.db.models import IndexAttemptError
class IndexAttemptErrorPydantic(BaseModel):
id: int
connector_credential_pair_id: int
document_id: str | None
document_link: str | None
entity_id: str | None
failed_time_range_start: datetime | None
failed_time_range_end: datetime | None
failure_message: str
is_resolved: bool = False
time_created: datetime
index_attempt_id: int
@classmethod
def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic":
return cls(
id=model.id,
connector_credential_pair_id=model.connector_credential_pair_id,
document_id=model.document_id,
document_link=model.document_link,
entity_id=model.entity_id,
failed_time_range_start=model.failed_time_range_start,
failed_time_range_end=model.failed_time_range_end,
failure_message=model.failure_message,
is_resolved=model.is_resolved,
time_created=model.time_created,
index_attempt_id=model.index_attempt_id,
)

View File

@@ -1,6 +1,5 @@
import time
import traceback
from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -8,42 +7,32 @@ from datetime import timezone
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.background.indexing.checkpointing_utils import check_checkpoint_size
from onyx.background.indexing.checkpointing_utils import get_latest_valid_checkpoint
from onyx.background.indexing.checkpointing_utils import save_checkpoint
from onyx.background.indexing.memory_tracer import MemoryTracer
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.background.indexing.checkpointing import get_time_windows_for_index_attempt
from onyx.background.indexing.tracer import OnyxTracer
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair
from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair
from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.index_attempt import transition_attempt_to_in_progress
from onyx.db.index_attempt import update_docs_indexed
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
from onyx.document_index.factory import get_default_document_index
@@ -55,7 +44,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -65,9 +53,9 @@ INDEXING_TRACER_NUM_PRINT_ENTRIES = 5
def _get_connector_runner(
db_session: Session,
attempt: IndexAttempt,
batch_size: int,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
) -> ConnectorRunner:
"""
@@ -86,12 +74,8 @@ def _get_connector_runner(
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
tenant_id=tenant_id,
)
# validate the connector settings
if not INTEGRATION_TESTS_MODE:
runnable_connector.validate_connector_settings()
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
@@ -116,9 +100,7 @@ def _get_connector_runner(
raise e
return ConnectorRunner(
connector=runnable_connector,
batch_size=batch_size,
time_range=(start_time, end_time),
connector=runnable_connector, time_range=(start_time, end_time)
)
@@ -177,80 +159,23 @@ class RunIndexingContext(BaseModel):
search_settings_status: IndexModelStatus
def _check_connector_and_attempt_status(
db_session_temp: Session, ctx: RunIndexingContext, index_attempt_id: int
) -> None:
"""
Checks the status of the connector credential pair and index attempt.
Raises a RuntimeError if any conditions are not met.
"""
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
if (
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING:
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_loop:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
def _check_failure_threshold(
total_failures: int,
document_count: int,
batch_num: int,
last_failure: ConnectorFailure | None,
) -> None:
"""Check if we've hit the failure threshold and raise an appropriate exception if so.
We consider the threshold hit if:
1. We have more than 3 failures AND
2. Failures account for more than 10% of processed documents
"""
failure_ratio = total_failures / (document_count or 1)
FAILURE_THRESHOLD = 3
FAILURE_RATIO_THRESHOLD = 0.1
if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD:
logger.error(
f"Connector run failed with '{total_failures}' errors "
f"after '{batch_num}' batches."
)
if last_failure and last_failure.exception:
raise last_failure.exception from last_failure.exception
raise RuntimeError(
f"Connector run encountered too many errors, aborting. "
f"Last error: {last_failure}"
)
def _run_indexing(
db_session: Session,
index_attempt_id: int,
tenant_id: str,
tenant_id: str | None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""
1. Get documents which are either new or updated from specified application
2. Embed and index these documents into the chosen datastore (vespa)
3. Updates Postgres to record the indexed documents + the outcome of this run
"""
start_time = time.monotonic() # jsut used for logging
with get_session_with_current_tenant() as db_session_temp:
TODO: do not change index attempt statuses here ... instead, set signals in redis
and allow the monitor function to clean them up
"""
start_time = time.time()
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_start:
raise ValueError(
@@ -296,46 +221,6 @@ def _run_indexing(
db_session=db_session_temp,
)
)
if last_successful_index_time > POLL_CONNECTOR_OFFSET:
window_start = datetime.fromtimestamp(
last_successful_index_time, tz=timezone.utc
) - timedelta(minutes=POLL_CONNECTOR_OFFSET)
else:
# don't go into "negative" time if we've never indexed before
window_start = datetime.fromtimestamp(0, tz=timezone.utc)
most_recent_attempt = next(
iter(
get_recent_completed_attempts_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt_start.search_settings_id,
db_session=db_session_temp,
limit=1,
)
),
None,
)
# if the last attempt failed, try and use the same window. This is necessary
# to ensure correctness with checkpointing. If we don't do this, things like
# new slack channels could be missed (since existing slack channels are
# cached as part of the checkpoint).
if (
most_recent_attempt
and most_recent_attempt.poll_range_end
and (
most_recent_attempt.status == IndexingStatus.FAILED
or most_recent_attempt.status == IndexingStatus.CANCELED
)
):
window_end = most_recent_attempt.poll_range_end
else:
window_end = datetime.now(tz=timezone.utc)
# add start/end now that they have been set
index_attempt_start.poll_range_start = window_start
index_attempt_start.poll_range_end = window_end
db_session_temp.add(index_attempt_start)
db_session_temp.commit()
embedding_model = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=index_attempt_start.search_settings,
@@ -349,6 +234,7 @@ def _run_indexing(
)
indexing_pipeline = build_indexing_pipeline(
attempt_id=index_attempt_id,
embedder=embedding_model,
document_index=document_index,
ignore_time_skip=(
@@ -360,72 +246,63 @@ def _run_indexing(
callback=callback,
)
# Initialize memory tracer. NOTE: won't actually do anything if
# `INDEXING_TRACER_INTERVAL` is 0.
memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL)
memory_tracer.start()
tracer: OnyxTracer
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(f"Memory tracer starting: interval={INDEXING_TRACER_INTERVAL}")
tracer = OnyxTracer()
tracer.start()
tracer.snap()
index_attempt_md = IndexAttemptMetadata(
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
total_failures = 0
batch_num = 0
net_doc_change = 0
document_count = 0
chunk_count = 0
try:
with get_session_with_current_tenant() as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
run_end_dt = None
tracer_counter: int
connector_runner = _get_connector_runner(
db_session=db_session_temp,
attempt=index_attempt,
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
for ind, (window_start, window_end) in enumerate(
get_time_windows_for_index_attempt(
last_successful_run=datetime.fromtimestamp(
last_successful_index_time, tz=timezone.utc
),
source_type=db_connector.source,
)
):
cc_pair_loop: ConnectorCredentialPair | None = None
index_attempt_loop: IndexAttempt | None = None
tracer_counter = 0
try:
window_start = max(
window_start - timedelta(minutes=POLL_CONNECTOR_OFFSET),
datetime(1970, 1, 1, tzinfo=timezone.utc),
)
# don't use a checkpoint if we're explicitly indexing from
# the beginning in order to avoid weird interactions between
# checkpointing / failure handling.
if index_attempt.from_beginning:
checkpoint = ConnectorCheckpoint.build_dummy_checkpoint()
else:
checkpoint = get_latest_valid_checkpoint(
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_loop_start = get_index_attempt(
db_session_temp, index_attempt_id
)
if not index_attempt_loop_start:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
connector_runner = _get_connector_runner(
db_session=db_session_temp,
cc_pair_id=ctx.cc_pair_id,
search_settings_id=index_attempt.search_settings_id,
window_start=window_start,
window_end=window_end,
attempt=index_attempt_loop_start,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
unresolved_errors = get_index_attempt_errors_for_cc_pair(
cc_pair_id=ctx.cc_pair_id,
unresolved_only=True,
db_session=db_session_temp,
)
doc_id_to_unresolved_errors: dict[
str, list[IndexAttemptError]
] = defaultdict(list)
for error in unresolved_errors:
if error.document_id:
doc_id_to_unresolved_errors[error.document_id].append(error)
entity_based_unresolved_errors = [
error for error in unresolved_errors if error.entity_id
]
while checkpoint.has_more:
logger.info(
f"Running '{ctx.source}' connector with checkpoint: {checkpoint}"
)
for document_batch, failure, next_checkpoint in connector_runner.run(
checkpoint
):
if INDEXING_TRACER_INTERVAL > 0:
tracer.snap()
for doc_batch in connector_runner.run():
# Check if connector is disabled mid run and stop if so unless it's the secondary
# index being built. We want to populate it even for paused connectors
# Often paused connectors are sources that aren't updated frequently but the
@@ -435,38 +312,42 @@ def _run_indexing(
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
with get_session_with_current_tenant() as db_session_temp:
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp, ctx, index_attempt_id
with get_session_with_tenant(tenant_id) as db_session_temp:
cc_pair_loop = get_connector_credential_pair_from_id(
db_session_temp,
ctx.cc_pair_id,
)
if not cc_pair_loop:
raise RuntimeError(f"CC pair {ctx.cc_pair_id} not found in DB.")
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_current_tenant() as db_session_temp:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
if (
(
cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED
and ctx.search_settings_status != IndexModelStatus.FUTURE
)
# if it's deleting, we don't care if this is a secondary index
or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING
):
# let the `except` block handle this
raise RuntimeError("Connector was disabled mid run")
index_attempt_loop = get_index_attempt(
db_session_temp, index_attempt_id
)
if not index_attempt_loop:
raise RuntimeError(
f"Index attempt {index_attempt_id} not found in DB."
)
_check_failure_threshold(
total_failures, document_count, batch_num, failure
)
# save the new checkpoint (if one is provided)
if next_checkpoint:
checkpoint = next_checkpoint
# below is all document processing logic, so if no batch we can just continue
if document_batch is None:
continue
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
# Likely due to user manually disabling it or model swap
raise RuntimeError(
f"Index Attempt was canceled, status is {index_attempt_loop.status}"
)
batch_description = []
doc_batch_cleaned = strip_null_characters(document_batch)
doc_batch_cleaned = strip_null_characters(doc_batch)
for doc in doc_batch_cleaned:
batch_description.append(doc.to_short_descriptor())
@@ -496,51 +377,15 @@ def _run_indexing(
chunk_count += index_pipeline_result.total_chunks
document_count += index_pipeline_result.total_docs
# resolve errors for documents that were successfully indexed
failed_document_ids = [
failure.failed_document.document_id
for failure in index_pipeline_result.failures
if failure.failed_document
]
successful_document_ids = [
document.id
for document in document_batch
if document.id not in failed_document_ids
]
for document_id in successful_document_ids:
with get_session_with_current_tenant() as db_session_temp:
if document_id in doc_id_to_unresolved_errors:
logger.info(
f"Resolving IndexAttemptError for document '{document_id}'"
)
for error in doc_id_to_unresolved_errors[document_id]:
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
# add brand new failures
if index_pipeline_result.failures:
total_failures += len(index_pipeline_result.failures)
with get_session_with_current_tenant() as db_session_temp:
for failure in index_pipeline_result.failures:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
failure,
db_session_temp,
)
_check_failure_threshold(
total_failures,
document_count,
batch_num,
index_pipeline_result.failures[-1],
)
# commit transaction so that the `update` below begins
# with a brand new transaction. Postgres uses the start
# of the transactions when computing `NOW()`, so if we have
# a long running transaction, the `time_updated` field will
# be inaccurate
db_session.commit()
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_current_tenant() as db_session_temp:
# NOTE: Postgres uses the start of the transactions when computing `NOW()`
# so we need either to commit() or to use a new session
with get_session_with_tenant(tenant_id) as db_session_temp:
update_docs_indexed(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
@@ -552,97 +397,126 @@ def _run_indexing(
if callback:
callback.progress("_run_indexing", len(doc_batch_cleaned))
memory_tracer.increment_and_maybe_trace()
tracer_counter += 1
if (
INDEXING_TRACER_INTERVAL > 0
and tracer_counter % INDEXING_TRACER_INTERVAL == 0
):
logger.debug(
f"Running trace comparison for batch {tracer_counter}. interval={INDEXING_TRACER_INTERVAL}"
)
tracer.snap()
tracer.log_previous_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
# `make sure the checkpoints aren't getting too large`at some regular interval
CHECKPOINT_SIZE_CHECK_INTERVAL = 100
if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0:
check_checkpoint_size(checkpoint)
run_end_dt = window_end
if ctx.is_primary:
with get_session_with_tenant(tenant_id) as db_session_temp:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
run_dt=run_end_dt,
)
except Exception as e:
logger.exception(
f"Connector run exceptioned after elapsed time: {time.time() - start_time} seconds"
)
# save latest checkpoint
with get_session_with_current_tenant() as db_session_temp:
save_checkpoint(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
checkpoint=checkpoint,
)
if isinstance(e, ConnectorStopSignal):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
except Exception as e:
logger.exception(
"Connector run exceptioned after elapsed time: "
f"{time.monotonic() - start_time} seconds"
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
else:
# Only mark the attempt as a complete failure if this is the first indexing window.
# Otherwise, some progress was made - the next run will not start from the beginning.
# In this case, it is not accurate to mark it as a failure. When the next run begins,
# if that fails immediately, it will be marked as a failure.
#
# NOTE: if the connector is manually disabled, we should mark it as a failure regardless
# to give better clarity in the UI, as the next run will never happen.
if (
ind == 0
or (
cc_pair_loop is not None and not cc_pair_loop.status.is_active()
)
or (
index_attempt_loop is not None
and index_attempt_loop.status != IndexingStatus.IN_PROGRESS
)
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
if INDEXING_TRACER_INTERVAL > 0:
tracer.stop()
raise e
# break => similar to success case. As mentioned above, if the next run fails for the same
# reason it will then be marked as a failure
break
if INDEXING_TRACER_INTERVAL > 0:
logger.debug(
f"Running trace comparison between start and end of indexing. {tracer_counter} batches processed."
)
if isinstance(e, ConnectorValidationError):
# On validation errors during indexing, we want to cancel the indexing attempt
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
tracer.snap()
tracer.log_first_diff(INDEXING_TRACER_NUM_PRINT_ENTRIES)
tracer.stop()
logger.debug("Memory tracer stopped.")
if (
index_attempt_md.num_exceptions > 0
and index_attempt_md.num_exceptions >= batch_num
):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason="All batches exceptioned.",
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
)
raise Exception(
f"Connector failed - All batches exceptioned: batches={batch_num}"
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
status=ConnectorCredentialPairStatus.INVALID,
)
memory_tracer.stop()
raise e
elapsed_time = time.time() - start_time
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
memory_tracer.stop()
raise e
else:
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
failure_reason=str(e),
full_exception_trace=traceback.format_exc(),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
net_docs=net_doc_change,
)
memory_tracer.stop()
raise e
memory_tracer.stop()
elapsed_time = time.monotonic() - start_time
with get_session_with_current_tenant() as db_session_temp:
# resolve entity-based errors
for error in entity_based_unresolved_errors:
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
error.is_resolved = True
db_session_temp.add(error)
db_session_temp.commit()
if total_failures == 0:
with get_session_with_tenant(tenant_id) as db_session_temp:
if index_attempt_md.num_exceptions == 0:
mark_attempt_succeeded(index_attempt_id, db_session_temp)
create_milestone_and_report(
@@ -661,7 +535,7 @@ def _run_indexing(
mark_attempt_partially_succeeded(index_attempt_id, db_session_temp)
logger.info(
f"Connector completed with some errors: "
f"failures={total_failures} "
f"exceptions={index_attempt_md.num_exceptions} "
f"batches={batch_num} "
f"docs={document_count} "
f"chunks={chunk_count} "
@@ -673,54 +547,57 @@ def _run_indexing(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
run_dt=window_end,
run_dt=run_end_dt,
)
def run_indexing_entrypoint(
index_attempt_id: int,
tenant_id: str,
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""Don't swallow exceptions here ... propagate them up."""
try:
if is_ee:
global_version.set_ee()
if is_ee:
global_version.set_ee()
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_current_tenant() as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if MULTI_TENANT:
tenant_str = f" for tenant {tenant_id}"
connector_name = attempt.connector_credential_pair.connector.name
connector_config = (
attempt.connector_credential_pair.connector.connector_specific_config
# set the indexing attempt ID so that all log messages from this process
# will have it added as a prefix
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
credential_id = attempt.connector_credential_pair.credential_id
with get_session_with_tenant(tenant_id) as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
tenant_str = ""
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
with get_session_with_current_tenant() as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
connector_name = attempt.connector_credential_pair.connector.name
connector_config = (
attempt.connector_credential_pair.connector.connector_specific_config
)
credential_id = attempt.connector_credential_pair.credential_id
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
logger.info(
f"Indexing starting{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
with get_session_with_tenant(tenant_id) as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(
f"Indexing finished{tenant_str}: "
f"connector='{connector_name}' "
f"config='{connector_config}' "
f"credentials='{credential_id}'"
)
except Exception as e:
logger.exception(
f"Indexing job with ID '{index_attempt_id}' for tenant {tenant_id} failed due to {e}"
)

View File

@@ -0,0 +1,77 @@
import tracemalloc
from onyx.utils.logger import setup_logger
logger = setup_logger()
DANSWER_TRACEMALLOC_FRAMES = 10
class OnyxTracer:
def __init__(self) -> None:
self.snapshot_first: tracemalloc.Snapshot | None = None
self.snapshot_prev: tracemalloc.Snapshot | None = None
self.snapshot: tracemalloc.Snapshot | None = None
def start(self) -> None:
tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES)
def stop(self) -> None:
tracemalloc.stop()
def snap(self) -> None:
snapshot = tracemalloc.take_snapshot()
# Filter out irrelevant frames (e.g., from tracemalloc itself or importlib)
snapshot = snapshot.filter_traces(
(
tracemalloc.Filter(False, tracemalloc.__file__), # Exclude tracemalloc
tracemalloc.Filter(
False, "<frozen importlib._bootstrap>"
), # Exclude importlib
tracemalloc.Filter(
False, "<frozen importlib._bootstrap_external>"
), # Exclude external importlib
)
)
if not self.snapshot_first:
self.snapshot_first = snapshot
if self.snapshot:
self.snapshot_prev = self.snapshot
self.snapshot = snapshot
def log_snapshot(self, numEntries: int) -> None:
if not self.snapshot:
return
stats = self.snapshot.statistics("traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer snap: {s}")
for line in s.traceback:
logger.debug(f"* {line}")
@staticmethod
def log_diff(
snap_current: tracemalloc.Snapshot,
snap_previous: tracemalloc.Snapshot,
numEntries: int,
) -> None:
stats = snap_current.compare_to(snap_previous, "traceback")
for s in stats[:numEntries]:
logger.debug(f"Tracer diff: {s}")
for line in s.traceback.format():
logger.debug(f"* {line}")
def log_previous_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_prev:
return
OnyxTracer.log_diff(self.snapshot, self.snapshot_prev, numEntries)
def log_first_diff(self, numEntries: int) -> None:
if not self.snapshot or not self.snapshot_first:
return
OnyxTracer.log_diff(self.snapshot, self.snapshot_first, numEntries)

View File

@@ -27,10 +27,8 @@ from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -82,26 +80,6 @@ class Answer:
and not skip_explicit_tool_calling
)
rerank_settings = search_request.rerank_settings
using_cloud_reranking = (
rerank_settings is not None
and rerank_settings.rerank_provider_type is not None
)
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
# TODO: this is a hack to force the query to be used for the search tool
# this should be removed once we fully unify graph inputs (i.e.
# remove SearchQuery entirely)
if (
force_use_tool.force_use
and search_tool
and force_use_tool.args
and force_use_tool.tool_name == search_tool.name
and QUERY_FIELD in force_use_tool.args
):
search_request.query = force_use_tool.args[QUERY_FIELD]
self.graph_inputs = GraphInputs(
search_request=search_request,
prompt_builder=prompt_builder,
@@ -116,6 +94,7 @@ class Answer:
force_use_tool=force_use_tool,
using_tool_calling_llm=using_tool_calling_llm,
)
assert db_session, "db_session must be provided for agentic persistence"
self.graph_persistence = GraphPersistence(
db_session=db_session,
chat_session_id=chat_session_id,
@@ -125,7 +104,6 @@ class Answer:
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
allow_refinement=True,
allow_agent_reranking=allow_agent_reranking,
)
self.graph_config = GraphConfig(
inputs=self.graph_inputs,

View File

@@ -190,8 +190,7 @@ def create_chat_chain(
and previous_message.message_type == MessageType.ASSISTANT
and mainline_messages
):
if current_message.refined_answer_improvement:
mainline_messages[-1] = current_message
mainline_messages[-1] = current_message
else:
mainline_messages.append(current_message)

View File

@@ -142,15 +142,6 @@ class MessageResponseIDInfo(BaseModel):
reserved_assistant_message_id: int
class AgentMessageIDInfo(BaseModel):
level: int
message_id: int
class AgenticMessageResponseIDInfo(BaseModel):
agentic_message_ids: list[AgentMessageIDInfo]
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None

View File

@@ -7,12 +7,10 @@ from typing import cast
from sqlalchemy.orm import Session
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
from onyx.chat.models import AgenticMessageResponseIDInfo
from onyx.chat.models import AgentMessageIDInfo
from onyx.chat.models import AgentSearchPacket
from onyx.chat.models import AllCitations
from onyx.chat.models import AnswerPostInfo
@@ -145,10 +143,9 @@ from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
def _translate_citations(
@@ -310,7 +307,6 @@ ChatPacket = (
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| AgenticMessageResponseIDInfo
| StreamStopInfo
| AgentSearchPacket
)
@@ -346,7 +342,7 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
tenant_id = get_current_tenant_id()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
@@ -635,7 +631,6 @@ def stream_chat_message_objects(
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
)
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
@@ -747,13 +742,14 @@ def stream_chat_message_objects(
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config, llm.config),
system_message=default_build_system_message(prompt_config),
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
raw_user_uploaded_files=latest_query_files or [],
single_message_history=single_message_history,
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
# LLM prompt building, response capturing, etc.
answer = Answer(
@@ -869,6 +865,7 @@ def stream_chat_message_objects(
for img in img_generation_response
if img.image_data
],
tenant_id=tenant_id,
)
info.ai_message_files.extend(
[
@@ -1018,7 +1015,7 @@ def stream_chat_message_objects(
if info.message_specific_citations
else None
),
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
error=None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
@@ -1036,7 +1033,6 @@ def stream_chat_message_objects(
next_level = 1
prev_message = gen_ai_response_message
agent_answers = answer.llm_answer_by_level()
agentic_message_ids = []
while next_level in agent_answers:
next_answer = agent_answers[next_level]
info = info_by_subq[
@@ -1057,12 +1053,7 @@ def stream_chat_message_objects(
citations=info.message_specific_citations.citation_map
if info.message_specific_citations
else None,
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
refined_answer_improvement=refined_answer_improvement,
is_agentic=True,
)
agentic_message_ids.append(
AgentMessageIDInfo(level=next_level, message_id=next_answer_message.id)
)
next_level += 1
prev_message = next_answer_message
@@ -1070,9 +1061,11 @@ def stream_chat_message_objects(
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield translate_db_message_to_chat_message_detail(gen_ai_response_message)
yield msg_detail_response
except Exception as e:
error_msg = str(e)
logger.exception(error_msg)

View File

@@ -12,7 +12,6 @@ from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_toke
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
@@ -20,7 +19,6 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
@@ -33,16 +31,8 @@ from onyx.tools.tool import Tool
def default_build_system_message(
prompt_config: PromptConfig,
llm_config: LLMConfig,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
# for o-series markdown generation
if (
llm_config.model_provider == OPENAI_PROVIDER_NAME
and llm_config.model_name.startswith("o")
):
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt,
prompt_config,
@@ -120,8 +110,21 @@ class AnswerPromptBuilder:
),
)
self.update_system_prompt(system_message)
self.update_user_prompt(user_message)
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
(
system_message,
check_message_tokens(system_message, self.llm_tokenizer_encode_func),
)
if system_message
else None
)
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(
user_message,
self.llm_tokenizer_encode_func,
),
)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

View File

@@ -31,9 +31,22 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
AGENT_ANSWER_GENERATION_BY_FAST_LLM = (
os.environ.get("AGENT_ANSWER_GENERATION_BY_FAST_LLM", "").lower() == "true"
)
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 1 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8 # in seconds
AGENT_RETRIEVAL_STATS = (
not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
@@ -165,172 +178,80 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
) # 2000
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 10 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION
)
AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 30 # in seconds
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION
)
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 3 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
) # 3
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 5 # in seconds
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
) # 30
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 5 # in seconds
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
) # 12
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK
)
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
) # 6
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 8 # in seconds
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
) # 1
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 3 # in seconds
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
) # 4
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 5 # in seconds
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
)
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 8 # in seconds
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
) # 8
GRAPH_VERSION_NAME: str = "a"

View File

@@ -6,7 +6,6 @@ from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
#####
@@ -30,9 +29,6 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
) # 1 day
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
)
#####
# Web Configs
@@ -173,11 +169,6 @@ POSTGRES_API_SERVER_POOL_SIZE = int(
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)
# defaults to False
# generally should only be used for
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"
# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"
@@ -630,10 +621,6 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
# Set to true to mock LLM responses for testing purposes

View File

@@ -98,18 +98,9 @@ CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# hard timeout applied by the watchdog to the indexing connector run
# to handle hung connectors
CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds)
# soft timeout for the lock taken by the indexing connector run
# allows the lock to eventually expire if the managing code around it dies
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
# CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 15 minutes
# hard termination should always fire first if the connector is hung
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
@@ -174,9 +165,6 @@ class DocumentSource(str, Enum):
EGNYTE = "egnyte"
AIRTABLE = "airtable"
# Special case just for integration tests
MOCK_CONNECTOR = "mock_connector"
DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -213,12 +201,6 @@ class AuthType(str, Enum):
CLOUD = "cloud"
class QueryHistoryType(str, Enum):
DISABLED = "disabled"
ANONYMIZED = "anonymized"
NORMAL = "normal"
# Special characters for password validation
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
@@ -261,7 +243,6 @@ class FileOrigin(str, Enum):
CHAT_IMAGE_GEN = "chat_image_gen"
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
INDEXING_CHECKPOINT = "indexing_checkpoint"
OTHER = "other"
@@ -293,7 +274,6 @@ class OnyxCeleryQueues:
DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert"
CONNECTOR_DELETION = "connector_deletion"
LLM_MODEL_UPDATE = "llm_model_update"
CHECKPOINT_CLEANUP = "checkpoint_cleanup"
# Heavy queue
CONNECTOR_PRUNING = "connector_pruning"
@@ -313,7 +293,6 @@ class OnyxRedisLocks:
CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat"
CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat"
CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat"
CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = "da_lock:check_checkpoint_cleanup_beat"
CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = (
"da_lock:check_connector_doc_permissions_sync_beat"
)
@@ -348,9 +327,6 @@ class OnyxRedisSignals:
BLOCK_PRUNING = "signal:block_pruning"
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
"signal:block_validate_connector_deletion_fences"
)
class OnyxRedisConstants:
@@ -392,10 +368,6 @@ class OnyxCeleryTask:
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
CHECK_FOR_LLM_MODEL_UPDATE = "check_for_llm_model_update"
# Connector checkpoint cleanup
CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup"
CLEANUP_CHECKPOINT = "cleanup_checkpoint"
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"

View File

@@ -245,7 +245,7 @@ class AirtableConnector(LoadConnector):
return [(" ".join(combined) if combined else str(field_info), default_link)]
if isinstance(field_info, list):
return [(str(item), default_link) for item in field_info]
return [(item, default_link) for item in field_info]
return [(str(field_info), default_link)]
@@ -268,7 +268,7 @@ class AirtableConnector(LoadConnector):
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[Section], dict[str, str | list[str]]]:
) -> tuple[list[Section], dict[str, Any]]:
"""
Process a single Airtable field and return sections or metadata.
@@ -342,7 +342,7 @@ class AirtableConnector(LoadConnector):
record_id = record["id"]
fields = record["fields"]
sections: list[Section] = []
metadata: dict[str, str | list[str]] = {}
metadata: dict[str, Any] = {}
# Get primary field value if it exists
primary_field_value = (

View File

@@ -7,18 +7,11 @@ from typing import Optional
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from botocore.exceptions import PartialCredentialsError
from mypy_boto3_s3 import S3Client # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import BlobType
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -247,73 +240,6 @@ class BlobStorageConnector(LoadConnector, PollConnector):
return None
def validate_connector_settings(self) -> None:
if self.s3_client is None:
raise ConnectorMissingCredentialError(
"Blob storage credentials not loaded."
)
if not self.bucket_name:
raise ConnectorValidationError(
"No bucket name was provided in connector settings."
)
try:
# We only fetch one object/page as a light-weight validation step.
# This ensures we trigger typical S3 permission checks (ListObjectsV2, etc.).
self.s3_client.list_objects_v2(
Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
)
except NoCredentialsError:
raise ConnectorMissingCredentialError(
"No valid blob storage credentials found or provided to boto3."
)
except PartialCredentialsError:
raise ConnectorMissingCredentialError(
"Partial or incomplete blob storage credentials provided to boto3."
)
except ClientError as e:
error_code = e.response["Error"].get("Code", "")
status_code = e.response["ResponseMetadata"].get("HTTPStatusCode")
# Most common S3 error cases
if error_code in [
"AccessDenied",
"InvalidAccessKeyId",
"SignatureDoesNotMatch",
]:
if status_code == 403 or error_code == "AccessDenied":
raise InsufficientPermissionsError(
f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
"Please check your bucket policy and/or IAM policy."
)
if status_code == 401 or error_code == "SignatureDoesNotMatch":
raise CredentialExpiredError(
"Provided blob storage credentials appear invalid or expired."
)
raise CredentialExpiredError(
f"Credential issue encountered ({error_code})."
)
if error_code == "NoSuchBucket" or status_code == 404:
raise ConnectorValidationError(
f"Bucket '{self.bucket_name}' does not exist or cannot be found."
)
raise ConnectorValidationError(
f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
)
except Exception as e:
# Catch-all for anything not captured by the above
# Since we are unsure of the error and it may not disable the connector,
# raise an unexpected error (does not disable connector)
raise UnexpectedError(
f"Unexpected error during blob storage settings validation: {e}"
)
if __name__ == "__main__":
credentials_dict = {

View File

@@ -5,8 +5,6 @@ import requests
class BookStackClientRequestFailedError(ConnectionError):
def __init__(self, status: int, error: str) -> None:
self.status_code = status
self.error = error
super().__init__(
"BookStack Client request failed with status {status}: {error}".format(
status=status, error=error

View File

@@ -7,11 +7,7 @@ from typing import Any
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.bookstack.client import BookStackApiClient
from onyx.connectors.bookstack.client import BookStackClientRequestFailedError
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -218,39 +214,3 @@ class BookstackConnector(LoadConnector, PollConnector):
break
else:
time.sleep(0.2)
def validate_connector_settings(self) -> None:
"""
Validate that the BookStack credentials and connector settings are correct.
Specifically checks that we can make an authenticated request to BookStack.
"""
if not self.bookstack_client:
raise ConnectorMissingCredentialError(
"BookStack credentials have not been loaded."
)
try:
# Attempt to fetch a small batch of books (arbitrary endpoint) to verify credentials
_ = self.bookstack_client.get(
"/books", params={"count": "1", "offset": "0"}
)
except BookStackClientRequestFailedError as e:
# Check for HTTP status codes
if e.status_code == 401:
raise CredentialExpiredError(
"Your BookStack credentials appear to be invalid or expired (HTTP 401)."
) from e
elif e.status_code == 403:
raise InsufficientPermissionsError(
"The configured BookStack token does not have sufficient permissions (HTTP 403)."
) from e
else:
raise ConnectorValidationError(
f"Unexpected BookStack error (status={e.status_code}): {e}"
) from e
except Exception as exc:
raise ConnectorValidationError(
f"Unexpected error while validating BookStack connector settings: {exc}"
) from exc

View File

@@ -4,8 +4,6 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote
from requests.exceptions import HTTPError
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -18,10 +16,6 @@ from onyx.connectors.confluence.utils import build_confluence_document_id
from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -403,33 +397,3 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
callback.progress("retrieve_all_slim_documents", 1)
yield doc_metadata_list
def validate_connector_settings(self) -> None:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence credentials not loaded.")
try:
spaces = self._confluence_client.get_all_spaces(limit=1)
except HTTPError as e:
status_code = e.response.status_code if e.response else None
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Confluence credentials (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Insufficient permissions to access Confluence resources (HTTP 403)."
)
raise UnexpectedError(
f"Unexpected Confluence error (status={status_code}): {e}"
)
except Exception as e:
raise UnexpectedError(
f"Unexpected error while validating Confluence settings: {e}"
)
if not spaces or not spaces.get("results"):
raise ConnectorValidationError(
"No Confluence spaces found. Either your credentials lack permissions, or "
"there truly are no spaces in this Confluence instance."
)

Some files were not shown because too many files have changed in this diff Show More