Mirror of https://github.com/onyx-dot-app/onyx.git, synced 2026-02-20 09:15:47 +00:00

Compare commits: readme-tip...gmail

102 Commits
ddaaaeeb40
4e31bc19dc
1c0c80116e
27891ad34d
511341fd7c
88d4e7defa
a7ba0da8cc
aaced6d551
4c230f92ea
07d75b04d1
a8d10750c1
85e3ed57f1
e10cc8ccdb
7018bc974b
9c9075d71d
338e084062
2f64031f5c
abb74f2eaa
a3e3d83b7e
4dc88ca037
11e7e1c4d6
f2d74ce540
25389c5120
ad0721ecd8
426a8842ae
a98dcbc7de
6f389dc100
d56177958f
0e42ae9024
ce2b4de245
a515aa78d2
23073d91b9
f767b1f476
9ffc8cb2c4
98bfb58147
6ce810e957
07b0b57b31
118cdd7701
ac83b4c365
fa408ff447
4aa8eb8b75
60bd9271f7
5d58a5e3ea
a99dd05533
0dce67094e
ffd14435a4
c9a3b45ad4
7d40676398
b9e79e5db3
558bbe16e4
076619ce2c
1263e21eb5
f0c13b6558
a7125662f1
4a4e4a6c50
1f2af373e1
bdaa293ae4
5a131f4547
ffb7d5b85b
fe8a5d671a
6de53ebf60
61d536c782
e1ff9086a4
ba21bacbbf
158bccc3fc
599b7705c2
4958a5355d
c4b8519381
8b4413694a
57cf7d9fac
ad4efb5f20
e304ec4ab6
1690dc45ba
7582ba1640
99fc546943
353c185856
7c96b7f24e
31524a3eff
c9f618798e
11f6b44625
e82a25f49e
5a9ec61446
9635522de8
630bdf71a3
47fd4fa233
2013beb9e0
466276161c
c934892c68
1daa3a663d
7324273233
2b2ba5478c
045a41d929
e3bc7cc747
0826b035a2
cf0e3d1ff4
10c81f75e2
5ca898bde2
58b252727f
86bd121806
9324f426c0
20d3efc86e
ec0e55fd39
94  .github/workflows/nightly-scan-licenses.yml (vendored)

@@ -53,24 +53,90 @@ jobs:
          exclude: '(?i)^(pylint|aio[-_]*).*'

      - name: Print report
        if: ${{ always() }}
        if: always()
        run: echo "${{ steps.license_check_report.outputs.report }}"

      - name: Install npm dependencies
        working-directory: ./web
        run: npm ci

      - name: Run Trivy vulnerability scanner in repo mode
        uses: aquasecurity/trivy-action@0.28.0
        with:
          scan-type: fs
          scanners: license
          format: table
          # format: sarif
          # output: trivy-results.sarif
          severity: HIGH,CRITICAL

      # - name: Upload Trivy scan results to GitHub Security tab
      #   uses: github/codeql-action/upload-sarif@v3
      # be careful enabling the sarif and upload as it may spam the security tab
      # with a huge amount of items. Work out the issues before enabling upload.
      # - name: Run Trivy vulnerability scanner in repo mode
      #   if: always()
      #   uses: aquasecurity/trivy-action@0.29.0
      #   with:
      #     sarif_file: trivy-results.sarif
      #     scan-type: fs
      #     scan-ref: .
      #     scanners: license
      #     format: table
      #     severity: HIGH,CRITICAL
      #     # format: sarif
      #     # output: trivy-results.sarif
      #
      #     # - name: Upload Trivy scan results to GitHub Security tab
      #     #   uses: github/codeql-action/upload-sarif@v3
      #     #   with:
      #     #     sarif_file: trivy-results.sarif

  scan-trivy:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]

    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      # Backend
      - name: Pull backend docker image
        run: docker pull onyxdotapp/onyx-backend:latest

      - name: Run Trivy vulnerability scanner on backend
        uses: aquasecurity/trivy-action@0.29.0
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: onyxdotapp/onyx-backend:latest
          scanners: license
          severity: HIGH,CRITICAL
          vuln-type: library
          exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow

      # Web server
      - name: Pull web server docker image
        run: docker pull onyxdotapp/onyx-web-server:latest

      - name: Run Trivy vulnerability scanner on web server
        uses: aquasecurity/trivy-action@0.29.0
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: onyxdotapp/onyx-web-server:latest
          scanners: license
          severity: HIGH,CRITICAL
          vuln-type: library
          exit-code: 0

      # Model server
      - name: Pull model server docker image
        run: docker pull onyxdotapp/onyx-model-server:latest

      - name: Run Trivy vulnerability scanner
        uses: aquasecurity/trivy-action@0.29.0
        env:
          TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
          TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
        with:
          image-ref: onyxdotapp/onyx-model-server:latest
          scanners: license
          severity: HIGH,CRITICAL
          vuln-type: library
          exit-code: 0
5  .github/workflows/pr-integration-tests.yml (vendored)

@@ -145,7 +145,7 @@ jobs:
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v

      # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
      - name: Start Docker containers
        run: |
@@ -157,6 +157,7 @@ jobs:
          REQUIRE_EMAIL_VERIFICATION=false \
          DISABLE_TELEMETRY=true \
          IMAGE_TAG=test \
          INTEGRATION_TESTS_MODE=true \
          docker compose -f docker-compose.dev.yml -p onyx-stack up -d
        id: start_docker
@@ -199,7 +200,7 @@ jobs:
          cd backend/tests/integration/mock_services
          docker compose -f docker-compose.mock-it-services.yml \
            -p mock-it-services-stack up -d

      # NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
      - name: Run Standard Integration Tests
        run: |

@@ -74,7 +74,9 @@ jobs:
          python -m pip install --upgrade pip
          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt

          playwright install chromium
          playwright install-deps chromium

      - name: Run Tests
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
97  .github/workflows/pr-python-model-tests.yml (vendored)

@@ -1,18 +1,29 @@
name: Connector Tests
name: Model Server Tests

on:
  schedule:
    # This cron expression runs the job daily at 16:00 UTC (9am PT)
    - cron: "0 16 * * *"

  workflow_dispatch:
    inputs:
      branch:
        description: 'Branch to run the workflow on'
        required: false
        default: 'main'

env:
  # Bedrock
  AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
  AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
  AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}

  # OpenAI
  # API keys for testing
  COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
  LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
  LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
  AZURE_API_URL: ${{ secrets.AZURE_API_URL }}

jobs:
  model-check:
@@ -26,6 +37,23 @@ jobs:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      # tag every docker image with "test" so that we can spin up the correct set
      # of images during testing

      # We don't need to build the Web Docker image since it's not yet used
      # in the integration tests. We have a separate action to verify that it builds
      # successfully.
      - name: Pull Model Server Docker image
        run: |
          docker pull onyxdotapp/onyx-model-server:latest
          docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test

      - name: Set up Python
        uses: actions/setup-python@v5
        with:
@@ -41,6 +69,49 @@ jobs:
          pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
          pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt

      - name: Start Docker containers
        run: |
          cd deployment/docker_compose
          ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
          AUTH_TYPE=basic \
          REQUIRE_EMAIL_VERIFICATION=false \
          DISABLE_TELEMETRY=true \
          IMAGE_TAG=test \
          docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
        id: start_docker

      - name: Wait for service to be ready
        run: |
          echo "Starting wait-for-service script..."

          start_time=$(date +%s)
          timeout=300 # 5 minutes in seconds

          while true; do
            current_time=$(date +%s)
            elapsed_time=$((current_time - start_time))

            if [ $elapsed_time -ge $timeout ]; then
              echo "Timeout reached. Service did not become ready in 5 minutes."
              exit 1
            fi

            # Use curl with error handling to ignore specific exit code 56
            response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")

            if [ "$response" = "200" ]; then
              echo "Service is ready!"
              break
            elif [ "$response" = "curl_error" ]; then
              echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
            else
              echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
            fi

            sleep 5
          done
          echo "Finished waiting for service."

      - name: Run Tests
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: |
@@ -56,3 +127,23 @@ jobs:
          -H 'Content-type: application/json' \
          --data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
          $SLACK_WEBHOOK

      - name: Dump all-container logs (optional)
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

      - name: Upload logs
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: docker-all-logs
          path: ${{ github.workspace }}/docker-compose.log

      - name: Stop Docker containers
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v
26  README.md

@@ -26,12 +26,12 @@
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
There are over 40 supported connectors such as Google Drive, Slack, Confluence, Salesforce, etc. which keep knowledge and permissions up to date.
Create custom AI agents with unique prompts, knowledge, and actions the agents can take.
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.


<h3>Feature Showcase</h3>
<h3>Feature Highlights</h3>

**Deep research over your team's knowledge:**

@@ -54,8 +54,7 @@ https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95


## Deployment
> [!TIP]
> To try it out for free and get started in seconds, check out **[Onyx Cloud](https://cloud.onyx.app/signup)**.
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.

Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
@@ -64,22 +63,21 @@ We also have built-in support for high-availability/scalable deployment on Kuber
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).


## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.


## 🚧 Roadmap
- Extensions to the Chrome Plugin
- Latest methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language


## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models only through Onyx + learn from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.


## 🔌 Connectors
Keep knowledge and access up to sync across 40+ connectors:
@@ -28,11 +28,11 @@ RUN apt-get update && \
        curl \
        zip \
        ca-certificates \
        libgnutls30=3.7.9-2+deb12u3 \
        libblkid1=2.38.1-5+deb12u1 \
        libmount1=2.38.1-5+deb12u1 \
        libsmartcols1=2.38.1-5+deb12u1 \
        libuuid1=2.38.1-5+deb12u1 \
        libgnutls30 \
        libblkid1 \
        libmount1 \
        libsmartcols1 \
        libuuid1 \
        libxmlsec1-dev \
        pkg-config \
        gcc \
@@ -0,0 +1,27 @@
"""Add indexes to document__tag

Revision ID: 1a03d2c2856b
Revises: 9c00a2bccb83
Create Date: 2025-02-18 10:45:13.957807

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "1a03d2c2856b"
down_revision = "9c00a2bccb83"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_index(
        op.f("ix_document__tag_tag_id"),
        "document__tag",
        ["tag_id"],
        unique=False,
    )


def downgrade() -> None:
    op.drop_index(op.f("ix_document__tag_tag_id"), table_name="document__tag")
84  backend/alembic/versions/3bd4c84fe72f_improved_index.py (Normal file)

@@ -0,0 +1,84 @@
"""improved index

Revision ID: 3bd4c84fe72f
Revises: 8f43500ee275
Create Date: 2025-02-26 13:07:56.217791

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None


# NOTE:
# This migration addresses issues with the previous migration (8f43500ee275) which caused
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
#    (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
#    (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search


def upgrade() -> None:
    # Create a GIN index for full-text search on chat_message.message
    op.execute(
        """
        ALTER TABLE chat_message
        ADD COLUMN message_tsv tsvector
        GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
        """
    )

    # Commit the current transaction before creating concurrent indexes
    op.execute("COMMIT")

    op.execute(
        """
        CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
        ON chat_message
        USING GIN (message_tsv)
        """
    )

    # Also add a stored tsvector column for chat_session.description
    op.execute(
        """
        ALTER TABLE chat_session
        ADD COLUMN description_tsv tsvector
        GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
        """
    )

    # Commit again before creating the second concurrent index
    op.execute("COMMIT")

    op.execute(
        """
        CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
        ON chat_session
        USING GIN (description_tsv)
        """
    )


def downgrade() -> None:
    # Drop the indexes first (use CONCURRENTLY for dropping too)
    op.execute("COMMIT")
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")

    op.execute("COMMIT")
    op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")

    # Then drop the columns
    op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
    op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")

    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
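Not part of the diff: a hedged sketch of how a query could exploit the GIN indexes this migration creates. The table and column names come from the migration itself; the session argument and the choice of websearch_to_tsquery are assumptions.

from sqlalchemy import text
from sqlalchemy.orm import Session


def search_chat_messages(db_session: Session, query: str) -> list[tuple]:
    # The @@ operator matches a tsquery against the stored tsvector, which
    # lets Postgres answer via idx_chat_message_tsv instead of a seq scan.
    result = db_session.execute(
        text(
            """
            SELECT id, message
            FROM chat_message
            WHERE message_tsv @@ websearch_to_tsquery('english', :q)
            ORDER BY ts_rank(message_tsv, websearch_to_tsquery('english', :q)) DESC
            LIMIT 20
            """
        ),
        {"q": query},
    )
    return list(result)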
32  backend/alembic/versions/8f43500ee275_add_index.py (Normal file)

@@ -0,0 +1,32 @@
"""add index

Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create a basic index on the lowercase message column for direct text matching
    # Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
    # op.execute(
    #     """
    #     CREATE INDEX idx_chat_message_message_lower
    #     ON chat_message (LOWER(substring(message, 1, 1500)))
    #     """
    # )
    pass


def downgrade() -> None:
    # Drop the index
    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
@@ -0,0 +1,43 @@
"""chat_message_agentic

Revision ID: 9c00a2bccb83
Revises: b7a7eee5aa15
Create Date: 2025-02-17 11:15:43.081150

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "9c00a2bccb83"
down_revision = "b7a7eee5aa15"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # First add the column as nullable
    op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))

    # Update existing rows based on presence of SubQuestions
    op.execute(
        """
        UPDATE chat_message
        SET is_agentic = EXISTS (
            SELECT 1
            FROM agent__sub_question
            WHERE agent__sub_question.primary_question_id = chat_message.id
        )
        WHERE is_agentic IS NULL
        """
    )

    # Make the column non-nullable with a default value of False
    op.alter_column(
        "chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
    )


def downgrade() -> None:
    op.drop_column("chat_message", "is_agentic")
@@ -0,0 +1,29 @@
"""remove inactive ccpair status on downgrade

Revision ID: acaab4ef4507
Revises: b388730a2899
Create Date: 2025-02-16 18:21:41.330212

"""
from alembic import op
from onyx.db.models import ConnectorCredentialPair
from onyx.db.enums import ConnectorCredentialPairStatus
from sqlalchemy import update

# revision identifiers, used by Alembic.
revision = "acaab4ef4507"
down_revision = "b388730a2899"
branch_labels = None
depends_on = None


def upgrade() -> None:
    pass


def downgrade() -> None:
    op.execute(
        update(ConnectorCredentialPair)
        .where(ConnectorCredentialPair.status == ConnectorCredentialPairStatus.INVALID)
        .values(status=ConnectorCredentialPairStatus.ACTIVE)
    )
@@ -0,0 +1,31 @@
"""nullable preferences

Revision ID: b388730a2899
Revises: 1a03d2c2856b
Create Date: 2025-02-17 18:49:22.643902

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "b388730a2899"
down_revision = "1a03d2c2856b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.alter_column("user", "temperature_override_enabled", nullable=True)
    op.alter_column("user", "auto_scroll", nullable=True)


def downgrade() -> None:
    # Ensure no null values before making columns non-nullable
    op.execute(
        'UPDATE "user" SET temperature_override_enabled = false WHERE temperature_override_enabled IS NULL'
    )
    op.execute('UPDATE "user" SET auto_scroll = false WHERE auto_scroll IS NULL')

    op.alter_column("user", "temperature_override_enabled", nullable=False)
    op.alter_column("user", "auto_scroll", nullable=False)
@@ -0,0 +1,120 @@
"""migrate jira connectors to new format

Revision ID: da42808081e3
Revises: f13db29f3101
Create Date: 2025-02-24 11:24:54.396040

"""
from alembic import op
import sqlalchemy as sa
import json

from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.utils import extract_jira_project


# revision identifiers, used by Alembic.
revision = "da42808081e3"
down_revision = "f13db29f3101"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Get all Jira connectors
    conn = op.get_bind()

    # First get all Jira connectors
    jira_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = :source
            """
        ),
        {"source": DocumentSource.JIRA.value.upper()},
    ).fetchall()

    # Update each connector's config
    for connector_id, old_config in jira_connectors:
        if not old_config:
            continue

        # Extract project key from URL if it exists
        new_config: dict[str, str | None] = {}
        if project_url := old_config.get("jira_project_url"):
            # Parse the URL to get base and project
            try:
                jira_base, project_key = extract_jira_project(project_url)
                new_config = {"jira_base_url": jira_base, "project_key": project_key}
            except ValueError:
                # If URL parsing fails, just use the URL as the base
                new_config = {
                    "jira_base_url": project_url.split("/projects/")[0],
                    "project_key": None,
                }
        else:
            # For connectors without a project URL, we need admin intervention
            # Mark these for review
            print(
                f"WARNING: Jira connector {connector_id} has no project URL configured"
            )
            continue

        # Update the connector config
        conn.execute(
            sa.text(
                """
                UPDATE connector
                SET connector_specific_config = :new_config
                WHERE id = :id
                """
            ),
            {"id": connector_id, "new_config": json.dumps(new_config)},
        )


def downgrade() -> None:
    # Get all Jira connectors
    conn = op.get_bind()

    # First get all Jira connectors
    jira_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = :source
            """
        ),
        {"source": DocumentSource.JIRA.value.upper()},
    ).fetchall()

    # Update each connector's config back to the old format
    for connector_id, new_config in jira_connectors:
        if not new_config:
            continue

        old_config = {}
        base_url = new_config.get("jira_base_url")
        project_key = new_config.get("project_key")

        if base_url and project_key:
            old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
        elif base_url:
            old_config = {"jira_project_url": base_url}
        else:
            continue

        # Update the connector config
        conn.execute(
            sa.text(
                """
                UPDATE connector
                SET connector_specific_config = :old_config
                WHERE id = :id
                """
            ),
            {"id": connector_id, "old_config": old_config},
        )
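The migration leans on extract_jira_project to split a project URL into a base URL and a project key. The real helper lives in onyx.connectors.onyx_jira.utils and is not shown in this diff; below is a hypothetical sketch of the parsing it presumably performs, assuming URLs of the form https://host/projects/KEY (mirroring the fallback split in the except branch above).

def extract_jira_project(url: str) -> tuple[str, str]:
    # Hypothetical reimplementation for illustration only; the real function
    # may handle more URL shapes (query params, trailing path segments).
    parts = url.split("/projects/")
    if len(parts) != 2 or not parts[1]:
        raise ValueError(f"Not a recognizable Jira project URL: {url}")
    jira_base = parts[0]
    project_key = parts[1].split("/")[0]
    return jira_base, project_key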
@@ -0,0 +1,36 @@
"""force lowercase all users

Revision ID: f11b408e39d3
Revises: 3bd4c84fe72f
Create Date: 2025-02-26 17:04:55.683500

"""


# revision identifiers, used by Alembic.
revision = "f11b408e39d3"
down_revision = "3bd4c84fe72f"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # 1) Convert all existing user emails to lowercase
    from alembic import op

    op.execute(
        """
        UPDATE "user"
        SET email = LOWER(email)
        """
    )

    # 2) Add a check constraint to ensure emails are always lowercase
    op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")


def downgrade() -> None:
    # Drop the check constraint
    from alembic import op

    op.drop_constraint("ensure_lowercase_email", "user", type_="check")
@@ -0,0 +1,27 @@
"""Add composite index for last_modified and last_synced to document

Revision ID: f13db29f3101
Revises: b388730a2899
Create Date: 2025-02-18 22:48:11.511389

"""
from alembic import op

# revision identifiers, used by Alembic.
revision = "f13db29f3101"
down_revision = "acaab4ef4507"
branch_labels: str | None = None
depends_on: str | None = None


def upgrade() -> None:
    op.create_index(
        "ix_document_sync_status",
        "document",
        ["last_modified", "last_synced"],
        unique=False,
    )


def downgrade() -> None:
    op.drop_index("ix_document_sync_status", table_name="document")
@@ -0,0 +1,42 @@
"""lowercase multi-tenant user auth

Revision ID: 34e3630c7f32
Revises: a4f6ee863c47
Create Date: 2025-02-26 15:03:01.211894

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "34e3630c7f32"
down_revision = "a4f6ee863c47"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # 1) Convert all existing rows to lowercase
    op.execute(
        """
        UPDATE user_tenant_mapping
        SET email = LOWER(email)
        """
    )
    # 2) Add a check constraint so that emails cannot be written in uppercase
    op.create_check_constraint(
        "ensure_lowercase_email",
        "user_tenant_mapping",
        "email = LOWER(email)",
        schema="public",
    )


def downgrade() -> None:
    # Drop the check constraint
    op.drop_constraint(
        "ensure_lowercase_email",
        "user_tenant_mapping",
        schema="public",
        type_="check",
    )
@@ -5,11 +5,9 @@ from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_sessions_older_than
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

@@ -18,10 +16,8 @@ logger = setup_logger()
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(
    retention_limit_days: int, *, tenant_id: str | None
) -> None:
    with get_session_with_tenant(tenant_id) as db_session:
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
    with get_session_with_current_tenant() as db_session:
        delete_chat_sessions_older_than(retention_limit_days, db_session)

@@ -35,24 +31,19 @@ def perform_ttl_management_task(
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str | None) -> None:
def check_ttl_management_task(*, tenant_id: str) -> None:
    """Runs periodically to check if any ttl tasks should be run and adds them
    to the queue"""
    token = None
    if MULTI_TENANT and tenant_id is not None:
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

    settings = load_settings()
    retention_limit_days = settings.maximum_chat_retention_days
    with get_session_with_tenant(tenant_id) as db_session:
    with get_session_with_current_tenant() as db_session:
        if should_perform_chat_ttl_check(retention_limit_days, db_session):
            perform_ttl_management_task.apply_async(
                kwargs=dict(
                    retention_limit_days=retention_limit_days, tenant_id=tenant_id
                ),
            )
    if token is not None:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


@celery_app.task(
@@ -60,9 +51,9 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
    """This generates usage report under the /admin/generate-usage/report endpoint"""
    with get_session_with_tenant(tenant_id) as db_session:
    with get_session_with_current_tenant() as db_session:
        create_new_usage_report(
            db_session=db_session,
            user_id=None,
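These hunks replace explicit tenant_id session plumbing (get_session_with_tenant(tenant_id) plus manual contextvar set/reset) with a helper that reads the tenant from a contextvar. A minimal runnable sketch of that handoff; the names mirror shared_configs.contextvars, but the stub session is an assumption (the real helpers yield tenant-scoped SQLAlchemy sessions):

from contextlib import contextmanager
from contextvars import ContextVar
from typing import Iterator

CURRENT_TENANT_ID_CONTEXTVAR: ContextVar[str] = ContextVar(
    "current_tenant_id", default="public"
)


@contextmanager
def get_session_with_current_tenant() -> Iterator[str]:
    # The real helper opens a DB session bound to the tenant schema named by
    # the contextvar; a string stands in for the session here.
    yield f"<session for tenant {CURRENT_TENANT_ID_CONTEXTVAR.get()}>"


def run_task_for_tenant(tenant_id: str) -> None:
    # A task wrapper sets the contextvar once on entry...
    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        # ...so downstream code opens tenant-scoped sessions with no
        # tenant_id parameter, which is what the hunks above rely on.
        with get_session_with_current_tenant() as session:
            print(session)
    finally:
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


run_task_for_tenant("tenant_abc")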
@@ -18,7 +18,7 @@ logger = setup_logger()
def monitor_usergroup_taskset(
    tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
    tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
    """This function is likely to move in the worker refactor happening next."""
    fence_key = key_bytes.decode("utf-8")
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import UserGroup__ConnectorCredentialPair
@@ -35,10 +36,11 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
def get_cc_pairs_by_source(
    db_session: Session,
    source_type: DocumentSource,
    only_sync: bool,
    access_type: AccessType | None = None,
    status: ConnectorCredentialPairStatus | None = None,
) -> list[ConnectorCredentialPair]:
    """
    Get all cc_pairs for a given source type (and optionally only sync)
    Get all cc_pairs for a given source type with optional filtering by access_type and status
    result is sorted by cc_pair id
    """
    query = (
@@ -48,8 +50,11 @@
        .order_by(ConnectorCredentialPair.id)
    )

    if only_sync:
        query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)
    if access_type is not None:
        query = query.filter(ConnectorCredentialPair.access_type == access_type)

    if status is not None:
        query = query.filter(ConnectorCredentialPair.status == status)

    cc_pairs = query.all()
    return cc_pairs
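A hypothetical call site (not part of the diff) showing how the removed only_sync flag maps onto the new parameters; the enums come from the imports in this file, and the source choice is arbitrary.

from sqlalchemy.orm import Session


def example_call_sites(db_session: Session) -> None:
    # Equivalent of the old only_sync=True behavior:
    sync_pairs = get_cc_pairs_by_source(
        db_session,
        DocumentSource.CONFLUENCE,
        access_type=AccessType.SYNC,
    )

    # New capability: additionally narrow to active pairs.
    active_sync_pairs = get_cc_pairs_by_source(
        db_session,
        DocumentSource.CONFLUENCE,
        access_type=AccessType.SYNC,
        status=ConnectorCredentialPairStatus.ACTIVE,
    )
    print(len(sync_pairs), len(active_sync_pairs))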
@@ -14,30 +14,24 @@ def _build_group_member_email_map(
    confluence_client: OnyxConfluence, cc_pair_id: int
) -> dict[str, set[str]]:
    group_member_emails: dict[str, set[str]] = {}
    for user_result in confluence_client.paginated_cql_user_retrieval():
        logger.debug(f"Processing groups for user: {user_result}")
    for user in confluence_client.paginated_cql_user_retrieval():
        logger.debug(f"Processing groups for user: {user}")

        user = user_result.get("user", {})
        if not user:
            msg = f"user result missing user field: {user_result}"
            emit_background_error(msg, cc_pair_id=cc_pair_id)
            logger.error(msg)
            continue

        email = user.get("email")
        email = user.email
        if not email:
            # This field is only present in Confluence Server
            user_name = user.get("username")
            user_name = user.username
            # If it is present, try to get the email using a Server-specific method
            if user_name:
                email = get_user_email_from_username__server(
                    confluence_client=confluence_client,
                    user_name=user_name,
                )

            if not email:
                # If we still don't have an email, skip this user
                msg = f"user result missing email field: {user_result}"
                if user.get("type") == "app":
                msg = f"user result missing email field: {user}"
                if user.type == "app":
                    logger.warning(msg)
                else:
                    emit_background_error(msg, cc_pair_id=cc_pair_id)
@@ -45,7 +39,7 @@ def _build_group_member_email_map(
                continue

        all_users_groups: set[str] = set()
        for group in confluence_client.paginated_groups_by_user_retrieval(user):
        for group in confluence_client.paginated_groups_by_user_retrieval(user.user_id):
            # group name uniqueness is enforced by Confluence, so we can use it as a group ID
            group_id = group["name"]
            group_member_emails.setdefault(group_id, set()).add(email)
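This hunk swaps raw dict user results for a typed user object (user.email, user.username, user.type, user.user_id). The actual model is defined elsewhere in the connector code and is not shown in this diff; the sketch below only infers its visible shape from the attribute accesses above.

from dataclasses import dataclass


@dataclass
class ConfluenceUser:
    # Field names inferred from the accesses in the hunk; anything beyond
    # these four fields is a guess.
    user_id: str
    email: str | None
    username: str | None  # only populated on Confluence Server
    type: str  # e.g. "app" for bot/app accounts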
@@ -62,12 +62,14 @@ def _fetch_permissions_for_permission_ids(
        user_email=(owner_email or google_drive_connector.primary_admin_email),
    )

    # We continue on 404 or 403 because the document may not exist or the user may not have access to it
    fetched_permissions = execute_paginated_retrieval(
        retrieval_function=drive_service.permissions().list,
        list_key="permissions",
        fileId=doc_id,
        fields="permissions(id, emailAddress, type, domain)",
        supportsAllDrives=True,
        continue_on_404_or_403=True,
    )

    permissions_for_doc_id = []
@@ -104,7 +106,13 @@ def _get_permissions_from_slim_doc(
    user_emails: set[str] = set()
    group_emails: set[str] = set()
    public = False
    skipped_permissions = 0

    for permission in permissions_list:
        if not permission:
            skipped_permissions += 1
            continue

        permission_type = permission["type"]
        if permission_type == "user":
            user_emails.add(permission["emailAddress"])
@@ -121,6 +129,11 @@ def _get_permissions_from_slim_doc(
        elif permission_type == "anyone":
            public = True

    if skipped_permissions > 0:
        logger.warning(
            f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
        )

    drive_id = permission_info.get("drive_id")
    group_ids = group_emails | ({drive_id} if drive_id is not None else set())
@@ -33,7 +33,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
            return await call_next(request)

        except Exception as e:
            logger.error(f"Error in tenant ID middleware: {str(e)}")
            logger.exception(f"Error in tenant ID middleware: {str(e)}")
            raise

@@ -49,7 +49,7 @@ async def _get_tenant_id_from_request(
    """
    # Check for API key
    tenant_id = extract_tenant_from_api_key_header(request)
    if tenant_id:
    if tenant_id is not None:
        return tenant_id

    # Check for anonymous user cookie
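The truthiness-to-None change in the second hunk is subtle enough to spell out (whether an empty-string tenant id can actually occur here is an assumption):

def api_key_resolved(tenant_id: str | None) -> bool:
    # Old check: bool(tenant_id) - an empty string "" would fall through to
    # the cookie/JWT lookups below.
    # New check: any non-None value, including "", counts as an answer from
    # the API-key path and is returned immediately.
    return tenant_id is not None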
@@ -36,12 +36,12 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -271,12 +271,12 @@ def prepare_authorization_request(
|
||||
connector: DocumentSource,
|
||||
redirect_on_success: str | None,
|
||||
user: User = Depends(current_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Used by the frontend to generate the url for the user's browser during auth request.
|
||||
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
@@ -329,7 +329,6 @@ def handle_slack_oauth_callback(
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
@@ -337,7 +336,7 @@ def handle_slack_oauth_callback(
|
||||
detail="Slack client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r = get_redis_client()
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
@@ -523,7 +522,6 @@ def handle_google_drive_oauth_callback(
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
@@ -531,7 +529,7 @@ def handle_google_drive_oauth_callback(
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r = get_redis_client()
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
|
||||
@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel


def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
def _check_token_rate_limits(user: User | None) -> None:
    if user is None:
        # Unauthenticated users are only rate limited by global settings
        _user_is_rate_limited_by_global(tenant_id)
        _user_is_rate_limited_by_global()

    elif is_api_key_email_address(user.email):
        # API keys are only rate limited by global settings
        _user_is_rate_limited_by_global(tenant_id)
        _user_is_rate_limited_by_global()

    else:
        run_functions_tuples_in_parallel(
            [
                (_user_is_rate_limited, (user.id, tenant_id)),
                (_user_is_rate_limited_by_group, (user.id, tenant_id)),
                (_user_is_rate_limited_by_global, (tenant_id,)),
                (_user_is_rate_limited, (user.id,)),
                (_user_is_rate_limited_by_group, (user.id,)),
                (_user_is_rate_limited_by_global, ()),
            ]
        )

@@ -52,8 +52,8 @@ User rate limits
"""


def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
    with get_session_with_tenant(tenant_id) as db_session:
def _user_is_rate_limited(user_id: UUID) -> None:
    with get_session_with_current_tenant() as db_session:
        user_rate_limits = fetch_all_user_token_rate_limits(
            db_session=db_session, enabled_only=True, ordered=False
        )
@@ -93,8 +93,8 @@ User Group rate limits
"""


def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
    with get_session_with_tenant(tenant_id) as db_session:
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
    with get_session_with_current_tenant() as db_session:
        group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)

        if group_rate_limits:
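The parallel branch above hands run_functions_tuples_in_parallel a list of (callable, args) tuples. The helper lives in onyx.utils.threadpool_concurrency and is not shown in this diff; the sketch below is an assumed minimal equivalent of its calling convention.

from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable


def run_functions_tuples_in_parallel(
    functions_with_args: list[tuple[Callable[..., Any], tuple]],
) -> list[Any]:
    # Each entry is (callable, positional-args tuple); all calls run
    # concurrently on a thread pool.
    with ThreadPoolExecutor(max_workers=max(1, len(functions_with_args))) as pool:
        futures = [pool.submit(fn, *args) for fn, args in functions_with_args]
        # result() re-raises, so a rate-limit exception raised by any one
        # check propagates to the caller of _check_token_rate_limits.
        return [f.result() for f in futures]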
@@ -2,6 +2,7 @@ import csv
import io
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from uuid import UUID

from fastapi import APIRouter
@@ -21,8 +22,10 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
@@ -35,6 +38,8 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse

router = APIRouter()

ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"


def fetch_and_process_chat_session_history(
    db_session: Session,
@@ -107,6 +112,17 @@ def get_user_chat_sessions(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
    # we specifically don't allow this endpoint if "anonymized" since
    # this is a direct query on the user id
    if ONYX_QUERY_HISTORY_TYPE in [
        QueryHistoryType.DISABLED,
        QueryHistoryType.ANONYMIZED,
    ]:
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN,
            detail="Per user query history has been disabled by the administrator.",
        )

    try:
        chat_sessions = get_chat_sessions_by_user(
            user_id=user_id, deleted=False, db_session=db_session, limit=0
@@ -122,6 +138,7 @@ def get_user_chat_sessions(
                name=chat.description,
                persona_id=chat.persona_id,
                time_created=chat.time_created.isoformat(),
                time_updated=chat.time_updated.isoformat(),
                shared_status=chat.shared_status,
                folder_id=chat.folder_id,
                current_alternate_model=chat.current_alternate_model,
@@ -141,6 +158,12 @@ def get_chat_session_history(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
    if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN,
            detail="Query history has been disabled by the administrator.",
        )

    page_of_chat_sessions = get_page_of_chat_sessions(
        page_num=page_num,
        page_size=page_size,
@@ -157,11 +180,16 @@ def get_chat_session_history(
        feedback_filter=feedback_type,
    )

    minimal_chat_sessions: list[ChatSessionMinimal] = []

    for chat_session in page_of_chat_sessions:
        minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
        if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
            minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
        minimal_chat_sessions.append(minimal_chat_session)

    return PaginatedReturn(
        items=[
            ChatSessionMinimal.from_chat_session(chat_session)
            for chat_session in page_of_chat_sessions
        ],
        items=minimal_chat_sessions,
        total_items=total_filtered_chat_sessions_count,
    )
@@ -172,6 +200,12 @@ def get_chat_session_admin(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
    if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN,
            detail="Query history has been disabled by the administrator.",
        )

    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=chat_session_id,
@@ -193,6 +227,9 @@ def get_chat_session_admin(
            f"Could not create snapshot for chat session with id '{chat_session_id}'",
        )

    if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
        snapshot.user_email = ONYX_ANONYMIZED_EMAIL

    return snapshot


@@ -203,6 +240,12 @@ def get_query_history_as_csv(
    end: datetime | None = None,
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
        raise HTTPException(
            status_code=HTTPStatus.FORBIDDEN,
            detail="Query history has been disabled by the administrator.",
        )

    complete_chat_session_history = fetch_and_process_chat_session_history(
        db_session=db_session,
        start=start or datetime.fromtimestamp(0, tz=timezone.utc),
@@ -213,6 +256,9 @@ def get_query_history_as_csv(

    question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
    for chat_session_snapshot in complete_chat_session_history:
        if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
            chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL

        question_answer_pairs.extend(
            QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
        )
@@ -41,14 +41,15 @@ from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
@@ -57,13 +58,14 @@ router = APIRouter(prefix="/tenants")
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
@@ -72,15 +74,15 @@ async def get_anonymous_user_path_api(
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
@@ -101,7 +103,7 @@ async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
@@ -150,14 +152,17 @@ async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
|
||||
tenant_id = get_current_tenant_id()
|
||||
    return fetch_billing_information(tenant_id)


@router.post("/create-customer-portal-session")
-async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
+async def create_customer_portal_session(
+    _: User = Depends(current_admin_user),
+) -> dict:
+    tenant_id = get_current_tenant_id()
+
     try:
-        # Fetch tenant_id and current tenant's information
-        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
         stripe_info = fetch_tenant_stripe_information(tenant_id)
         stripe_customer_id = stripe_info.get("stripe_customer_id")
         if not stripe_customer_id:

@@ -181,6 +186,8 @@ async def create_subscription_session(
 ) -> SubscriptionSessionResponse:
     try:
         tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
+        if not tenant_id:
+            raise HTTPException(status_code=400, detail="Tenant ID not found")
         session_id = fetch_stripe_checkout_session(tenant_id)
         return SubscriptionSessionResponse(sessionId=session_id)

@@ -197,7 +204,7 @@ async def impersonate_user(
     """Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
     tenant_id = get_tenant_id_for_email(impersonate_request.email)

-    with get_session_with_tenant(tenant_id) as tenant_session:
+    with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
         user_to_impersonate = get_user_by_email(
             impersonate_request.email, tenant_session
         )

@@ -221,8 +228,9 @@ async def leave_organization(
     user_email: UserByEmail,
     current_user: User | None = Depends(current_admin_user),
     db_session: Session = Depends(get_session),
-    tenant_id: str = Depends(get_current_tenant_id),
 ) -> None:
+    tenant_id = get_current_tenant_id()
+
     if current_user is None or current_user.email != user_email.user_email:
         raise HTTPException(
             status_code=403, detail="You can only leave the organization as yourself"

@@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
 from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
 from ee.onyx.server.tenants.access import generate_data_plane_token
 from ee.onyx.server.tenants.models import BillingInformation
+from ee.onyx.server.tenants.models import SubscriptionStatusResponse
 from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
 from onyx.utils.logger import setup_logger

@@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
     return response.json()


-def fetch_billing_information(tenant_id: str) -> BillingInformation:
+def fetch_billing_information(
+    tenant_id: str,
+) -> BillingInformation | SubscriptionStatusResponse:
     logger.info("Fetching billing information")
     token = generate_data_plane_token()
     headers = {

@@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation:
     params = {"tenant_id": tenant_id}
     response = requests.get(url, headers=headers, params=params)
     response.raise_for_status()
-    billing_info = BillingInformation(**response.json())
-    return billing_info
+
+    response_data = response.json()
+
+    # Check if the response indicates no subscription
+    if (
+        isinstance(response_data, dict)
+        and "subscribed" in response_data
+        and not response_data["subscribed"]
+    ):
+        return SubscriptionStatusResponse(**response_data)
+
+    # Otherwise, parse as BillingInformation
+    return BillingInformation(**response_data)


 def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
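
Because fetch_billing_information now returns a union, call sites have to branch on which model came back. A minimal, purely illustrative sketch of such a caller (the helper function itself is hypothetical; only the two model names and the `subscribed` convention come from the diff):

```python
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse


def describe_billing(tenant_id: str) -> str:
    # SubscriptionStatusResponse is returned when the control plane reports
    # {"subscribed": false}; otherwise a full BillingInformation is parsed.
    info = fetch_billing_information(tenant_id)
    if isinstance(info, SubscriptionStatusResponse):
        return "no active subscription"
    assert isinstance(info, BillingInformation)
    return "active subscription"
```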
@@ -118,7 +118,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
     # Await the Alembic migrations
     await asyncio.to_thread(run_alembic_migrations, tenant_id)

-    with get_session_with_tenant(tenant_id) as db_session:
+    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
         configure_default_api_keys(db_session)

         current_search_settings = (

@@ -134,7 +134,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:

     add_users_to_tenant([email], tenant_id)

-    with get_session_with_tenant(tenant_id) as db_session:
+    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
         create_milestone_and_report(
             user=None,
             distinct_id=tenant_id,

@@ -200,33 +200,15 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:


 def configure_default_api_keys(db_session: Session) -> None:
-    if OPENAI_DEFAULT_API_KEY:
-        open_provider = LLMProviderUpsertRequest(
-            name="OpenAI",
-            provider=OPENAI_PROVIDER_NAME,
-            api_key=OPENAI_DEFAULT_API_KEY,
-            default_model_name="gpt-4",
-            fast_default_model_name="gpt-4o-mini",
-            model_names=OPEN_AI_MODEL_NAMES,
-        )
-        try:
-            full_provider = upsert_llm_provider(open_provider, db_session)
-            update_default_provider(full_provider.id, db_session)
-        except Exception as e:
-            logger.error(f"Failed to configure OpenAI provider: {e}")
-    else:
-        logger.error(
-            "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
-        )
-
     if ANTHROPIC_DEFAULT_API_KEY:
         anthropic_provider = LLMProviderUpsertRequest(
             name="Anthropic",
             provider=ANTHROPIC_PROVIDER_NAME,
             api_key=ANTHROPIC_DEFAULT_API_KEY,
-            default_model_name="claude-3-5-sonnet-20241022",
+            default_model_name="claude-3-7-sonnet-20250219",
             fast_default_model_name="claude-3-5-sonnet-20241022",
             model_names=ANTHROPIC_MODEL_NAMES,
+            display_model_names=["claude-3-5-sonnet-20241022"],
         )
         try:
             full_provider = upsert_llm_provider(anthropic_provider, db_session)

@@ -238,6 +220,26 @@ def configure_default_api_keys(db_session: Session) -> None:
             "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
         )

+    if OPENAI_DEFAULT_API_KEY:
+        open_provider = LLMProviderUpsertRequest(
+            name="OpenAI",
+            provider=OPENAI_PROVIDER_NAME,
+            api_key=OPENAI_DEFAULT_API_KEY,
+            default_model_name="gpt-4o",
+            fast_default_model_name="gpt-4o-mini",
+            model_names=OPEN_AI_MODEL_NAMES,
+            display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
+        )
+        try:
+            full_provider = upsert_llm_provider(open_provider, db_session)
+            update_default_provider(full_provider.id, db_session)
+        except Exception as e:
+            logger.error(f"Failed to configure OpenAI provider: {e}")
+    else:
+        logger.error(
+            "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
+        )
+
     if COHERE_DEFAULT_API_KEY:
         cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
             provider_type=EmbeddingProvider.COHERE,
@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:


 def user_owns_a_tenant(email: str) -> bool:
-    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
+    with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
         result = (
             db_session.query(UserTenantMapping)
             .filter(UserTenantMapping.email == email)

@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:


 def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
-    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
+    with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
         try:
             for email in emails:
                 db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))

@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:


 def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
-    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
+    with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
         try:
             mappings_to_delete = (
                 db_session.query(UserTenantMapping)

@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:


 def remove_all_users_from_tenant(tenant_id: str) -> None:
-    with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
+    with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
         db_session.query(UserTenantMapping).filter(
             UserTenantMapping.tenant_id == tenant_id
         ).delete()
@@ -98,12 +98,17 @@ class CloudEmbedding:
             return final_embeddings
         except Exception as e:
             error_string = (
-                f"Error embedding text with OpenAI: {str(e)} \n"
-                f"Model: {model} \n"
-                f"Provider: {self.provider} \n"
-                f"Texts: {texts}"
+                f"Exception embedding text with OpenAI - {type(e)}: "
+                f"Model: {model} "
+                f"Provider: {self.provider} "
+                f"Exception: {e}"
             )
             logger.error(error_string)
+
+            # only log text when it's not an authentication error.
+            if not isinstance(e, openai.AuthenticationError):
+                logger.debug(f"Exception texts: {texts}")
+
             raise RuntimeError(error_string)

     async def _embed_cohere(
@@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
 from onyx.agents.agent_search.basic.states import BasicInput
 from onyx.agents.agent_search.basic.states import BasicOutput
 from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
-    basic_use_tool_response,
-)
-from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
+from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
+from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
 from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
     prepare_tool_input,
 )
-from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
+from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
+    basic_use_tool_response,
+)
 from onyx.utils.logger import setup_logger

 logger = setup_logger()

@@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
     )

     graph.add_node(
-        node="llm_tool_choice",
-        action=llm_tool_choice,
+        node="choose_tool",
+        action=choose_tool,
     )

     graph.add_node(
-        node="tool_call",
-        action=tool_call,
+        node="call_tool",
+        action=call_tool,
     )

     graph.add_node(

@@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:

     graph.add_edge(start_key=START, end_key="prepare_tool_input")

-    graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
+    graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")

-    graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
+    graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])

     graph.add_edge(
-        start_key="tool_call",
+        start_key="call_tool",
         end_key="basic_use_tool_response",
     )

@@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
         # If there are no tool calls, basic graph already streamed the answer
         END
         if state.tool_choice is None
-        else "tool_call"
+        else "call_tool"
     )
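
For context on the LangGraph API these renames touch, here is a self-contained toy graph wired the same way; the trivial state and node bodies are illustrative stand-ins, not Onyx's real BasicState or node implementations:

```python
from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class ToyState(TypedDict):
    question: str
    answer: str


def choose_tool(state: ToyState) -> ToyState:
    # decide whether a tool call is needed; a no-op in this sketch
    return state


def call_tool(state: ToyState) -> ToyState:
    return {"question": state["question"], "answer": "tool result"}


def should_continue(state: ToyState) -> str:
    # mirror the renamed edge targets: run the tool or finish
    return "call_tool" if state["answer"] == "" else END


graph = StateGraph(ToyState)
graph.add_node(node="choose_tool", action=choose_tool)
graph.add_node(node="call_tool", action=call_tool)
graph.add_edge(START, "choose_tool")
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
graph.add_edge("call_tool", END)

compiled = graph.compile()
print(compiled.invoke({"question": "hi", "answer": ""}))  # answer == "tool result"
```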
@@ -31,12 +31,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
-from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
+from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
+from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
 from onyx.prompts.agent_search import UNKNOWN_ANSWER
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()

@@ -85,9 +87,11 @@ def check_sub_answer(
     agent_error: AgentErrorLog | None = None
     response: BaseMessage | None = None
     try:
-        response = fast_llm.invoke(
+        response = run_with_timeout(
+            AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
+            fast_llm.invoke,
             prompt=msg,
-            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
+            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
         )

         quality_str: str = cast(str, response.content)

@@ -96,7 +100,7 @@ def check_sub_answer(
         )
         log_result = f"Answer quality: {quality_str}"

-    except LLMTimeoutError:
+    except (LLMTimeoutError, TimeoutError):
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
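
The pattern above recurs throughout this compare: each blocking LLM call is wrapped in run_with_timeout, and the except clause widens to (LLMTimeoutError, TimeoutError) because a thread-based wall-clock timeout surfaces as the builtin TimeoutError rather than the connect-level LLMTimeoutError. A rough sketch of what such a helper could look like; the real onyx.utils.threadpool_concurrency.run_with_timeout may differ in detail:

```python
from concurrent.futures import ThreadPoolExecutor
from concurrent.futures import TimeoutError as FuturesTimeoutError
from typing import Any, Callable, TypeVar

R = TypeVar("R")


def run_with_timeout(
    timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
    """Run func in a worker thread and give up after `timeout` seconds."""
    executor = ThreadPoolExecutor(max_workers=1)
    try:
        future = executor.submit(func, *args, **kwargs)
        try:
            return future.result(timeout=timeout)
        except FuturesTimeoutError:
            # surface the builtin TimeoutError that the call sites now catch
            raise TimeoutError(f"timed out after {timeout}s")
    finally:
        # don't block waiting for a still-running worker thread
        executor.shutdown(wait=False)
```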
@@ -1,5 +1,4 @@
 from datetime import datetime
-from typing import Any
 from typing import cast

 from langchain_core.messages import merge_message_runs

@@ -47,11 +46,13 @@ from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
-from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
+from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
+from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import NO_RECOVERED_DOCS
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()

@@ -110,15 +111,14 @@ def generate_sub_answer(
         config=fast_llm.config,
     )

-    response: list[str | list[str | dict[str, Any]]] = []
     dispatch_timings: list[float] = []

     agent_error: AgentErrorLog | None = None
+    response: list[str] = []

-    try:
+    def stream_sub_answer() -> list[str]:
         for message in fast_llm.stream(
             prompt=msg,
-            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
+            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
         ):
             # TODO: in principle, the answer here COULD contain images, but we don't support that yet
             content = message.content

@@ -142,8 +142,15 @@ def generate_sub_answer(
                 (end_stream_token - start_stream_token).microseconds
             )
             response.append(content)
+        return response

-    except LLMTimeoutError:
+    try:
+        response = run_with_timeout(
+            AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
+            stream_sub_answer,
+        )
+
+    except (LLMTimeoutError, TimeoutError):
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
@@ -1,5 +1,4 @@
 from datetime import datetime
-from typing import Any
 from typing import cast

 from langchain_core.messages import HumanMessage

@@ -60,11 +59,15 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.chat.models import StreamingError
+from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
 from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
 from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
+    AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
 )
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
+)
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError

@@ -77,6 +80,7 @@ from onyx.prompts.agent_search import (
 )
 from onyx.prompts.agent_search import UNKNOWN_ANSWER
 from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
+from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 _llm_node_error_strings = LLMNodeErrorStrings(

@@ -230,7 +234,11 @@ def generate_initial_answer(

     sub_questions = all_sub_questions  # Replace the original assignment

-    model = graph_config.tooling.fast_llm
+    model = (
+        graph_config.tooling.fast_llm
+        if AGENT_ANSWER_GENERATION_BY_FAST_LLM
+        else graph_config.tooling.primary_llm
+    )

     doc_context = format_docs(answer_generation_documents.context_documents)
     doc_context = trim_prompt_piece(

@@ -260,15 +268,16 @@ def generate_initial_answer(
         )
     ]

-    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
+    streamed_tokens: list[str] = [""]
     dispatch_timings: list[float] = []

     agent_error: AgentErrorLog | None = None

-    try:
+    def stream_initial_answer() -> list[str]:
+        response: list[str] = []
         for message in model.stream(
             msg,
-            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
+            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
         ):
             # TODO: in principle, the answer here COULD contain images, but we don't support that yet
             content = message.content

@@ -292,9 +301,16 @@ def generate_initial_answer(
             dispatch_timings.append(
                 (end_stream_token - start_stream_token).microseconds
             )
-            streamed_tokens.append(content)
+            response.append(content)
+        return response

-    except LLMTimeoutError:
+    try:
+        streamed_tokens = run_with_timeout(
+            AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
+            stream_initial_answer,
+        )
+
+    except (LLMTimeoutError, TimeoutError):
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
@@ -36,7 +36,10 @@ from onyx.chat.models import StreamType
 from onyx.chat.models import SubQuestionPiece
 from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
+    AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
 )
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
+)
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError

@@ -47,6 +50,7 @@ from onyx.prompts.agent_search import (
     INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
 )
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()

@@ -131,10 +135,12 @@ def decompose_orig_question(
     streamed_tokens: list[BaseMessage_Content] = []

     try:
-        streamed_tokens = dispatch_separated(
+        streamed_tokens = run_with_timeout(
+            AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
+            dispatch_separated,
             model.stream(
                 msg,
-                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
             ),
             dispatch_subquestion(0, writer),
             sep_callback=dispatch_subquestion_sep(0, writer),

@@ -154,7 +160,7 @@ def decompose_orig_question(
         )
         write_custom_event("stream_finished", stop_event, writer)

-    except LLMTimeoutError as e:
+    except (LLMTimeoutError, TimeoutError) as e:
         logger.error("LLM Timeout Error - decompose orig question")
         raise e  # fail loudly on this critical step
     except LLMRateLimitError as e:
@@ -25,7 +25,7 @@ logger = setup_logger()

 def route_initial_tool_choice(
     state: MainState, config: RunnableConfig
-) -> Literal["tool_call", "start_agent_search", "logging_node"]:
+) -> Literal["call_tool", "start_agent_search", "logging_node"]:
     """
     LangGraph edge to route to agent search.
     """

@@ -38,7 +38,7 @@ def route_initial_tool_choice(
         ):
             return "start_agent_search"
         else:
-            return "tool_call"
+            return "call_tool"
     else:
         return "logging_node"
@@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
     answer_refined_query_graph_builder,
 )
-from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
-    basic_use_tool_response,
-)
-from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
+from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
+from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
 from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
     prepare_tool_input,
 )
-from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
+from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
+    basic_use_tool_response,
+)
 from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
 from onyx.utils.logger import setup_logger

@@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     # Choose the initial tool
     graph.add_node(
         node="initial_tool_choice",
-        action=llm_tool_choice,
+        action=choose_tool,
     )

     # Call the tool, if required
     graph.add_node(
-        node="tool_call",
-        action=tool_call,
+        node="call_tool",
+        action=call_tool,
     )

     # Use the tool response

@@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     graph.add_conditional_edges(
         "initial_tool_choice",
         route_initial_tool_choice,
-        ["tool_call", "start_agent_search", "logging_node"],
+        ["call_tool", "start_agent_search", "logging_node"],
     )

     graph.add_edge(
-        start_key="tool_call",
+        start_key="call_tool",
         end_key="basic_use_tool_response",
     )
     graph.add_edge(
@@ -33,13 +33,15 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import RefinedAnswerImprovement
-from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
+from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
+from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
 )
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()

@@ -105,11 +107,14 @@ def compare_answers(
     refined_answer_improvement: bool | None = None
     # no need to stream this
     try:
-        resp = model.invoke(
-            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
+        resp = run_with_timeout(
+            AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
+            model.invoke,
+            prompt=msg,
+            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
         )

-    except LLMTimeoutError:
+    except (LLMTimeoutError, TimeoutError):
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
@@ -44,7 +44,10 @@ from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import StreamingError
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
+    AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
 )
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
+)
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError

@@ -53,6 +56,7 @@ from onyx.prompts.agent_search import (
 )
 from onyx.tools.models import ToolCallKickoff
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()

@@ -134,15 +138,17 @@ def create_refined_sub_questions(
     agent_error: AgentErrorLog | None = None
     streamed_tokens: list[BaseMessage_Content] = []
     try:
-        streamed_tokens = dispatch_separated(
+        streamed_tokens = run_with_timeout(
+            AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
+            dispatch_separated,
             model.stream(
                 msg,
-                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
             ),
             dispatch_subquestion(1, writer),
             sep_callback=dispatch_subquestion_sep(1, writer),
         )
-    except LLMTimeoutError:
+    except (LLMTimeoutError, TimeoutError):
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
@@ -22,11 +22,17 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
+    AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
 )
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
+)
 from onyx.configs.constants import NUM_EXPLORATORY_DOCS
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
 from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
+from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

@@ -84,30 +90,42 @@ def extract_entities_terms(
     ]
     fast_llm = graph_config.tooling.fast_llm
     # Grader
-    llm_response = fast_llm.invoke(
-        prompt=msg,
-        timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
-    )
-
-    cleaned_response = (
-        str(llm_response.content).replace("```json\n", "").replace("\n```", "")
-    )
-    first_bracket = cleaned_response.find("{")
-    last_bracket = cleaned_response.rfind("}")
-    cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
-
     try:
-        entity_extraction_result = EntityExtractionResult.model_validate_json(
-            cleaned_response
-        )
-    except ValueError:
-        logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
+        llm_response = run_with_timeout(
+            AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
+            fast_llm.invoke,
+            prompt=msg,
+            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
+        )
+
+        cleaned_response = (
+            str(llm_response.content).replace("```json\n", "").replace("\n```", "")
+        )
+        first_bracket = cleaned_response.find("{")
+        last_bracket = cleaned_response.rfind("}")
+        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
+
+        try:
+            entity_extraction_result = EntityExtractionResult.model_validate_json(
+                cleaned_response
+            )
+        except ValueError:
+            logger.error(
+                "Failed to parse LLM response as JSON in Entity-Term Extraction"
+            )
+            entity_extraction_result = EntityExtractionResult(
+                retrieved_entities_relationships=EntityRelationshipTermExtraction(),
+            )
+    except (LLMTimeoutError, TimeoutError):
+        logger.error("LLM Timeout Error - extract entities terms")
         entity_extraction_result = EntityExtractionResult(
-            retrieved_entities_relationships=EntityRelationshipTermExtraction(
-                entities=[],
-                relationships=[],
-                terms=[],
-            ),
+            retrieved_entities_relationships=EntityRelationshipTermExtraction(),
         )

+    except LLMRateLimitError:
+        logger.error("LLM Rate Limit Error - extract entities terms")
+        entity_extraction_result = EntityExtractionResult(
+            retrieved_entities_relationships=EntityRelationshipTermExtraction(),
+        )
+
     return EntityTermExtractionUpdate(
@@ -1,5 +1,4 @@
 from datetime import datetime
-from typing import Any
 from typing import cast

 from langchain_core.messages import HumanMessage

@@ -66,14 +65,21 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.chat.models import StreamingError
+from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
 from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
 from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
+    AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
 )
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION,
+    AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
 )
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
+)
+from onyx.configs.agent_configs import (
+    AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
+)
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError

@@ -92,6 +98,7 @@ from onyx.prompts.agent_search import (
 )
 from onyx.prompts.agent_search import UNKNOWN_ANSWER
 from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()

@@ -253,7 +260,12 @@ def generate_validate_refined_answer(
         else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
     )

-    model = graph_config.tooling.fast_llm
+    model = (
+        graph_config.tooling.fast_llm
+        if AGENT_ANSWER_GENERATION_BY_FAST_LLM
+        else graph_config.tooling.primary_llm
+    )

     relevant_docs_str = format_docs(answer_generation_documents.context_documents)
     relevant_docs_str = trim_prompt_piece(
         model.config,

@@ -284,13 +296,13 @@ def generate_validate_refined_answer(
         )
     ]

-    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
+    streamed_tokens: list[str] = [""]
     dispatch_timings: list[float] = []
     agent_error: AgentErrorLog | None = None

-    try:
+    def stream_refined_answer() -> list[str]:
         for message in model.stream(
-            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
+            msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
         ):
             # TODO: in principle, the answer here COULD contain images, but we don't support that yet
             content = message.content

@@ -315,8 +327,15 @@ def generate_validate_refined_answer(
                 (end_stream_token - start_stream_token).microseconds
             )
             streamed_tokens.append(content)
+        return streamed_tokens

-    except LLMTimeoutError:
+    try:
+        streamed_tokens = run_with_timeout(
+            AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
+            stream_refined_answer,
+        )
+
+    except (LLMTimeoutError, TimeoutError):
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,

@@ -383,16 +402,20 @@ def generate_validate_refined_answer(
         )
     ]

+    validation_model = graph_config.tooling.fast_llm
     try:
-        validation_response = model.invoke(
-            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
+        validation_response = run_with_timeout(
+            AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
+            validation_model.invoke,
+            prompt=msg,
+            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
         )
         refined_answer_quality = binary_string_test_after_answer_separator(
             text=cast(str, validation_response.content),
             positive_value=AGENT_POSITIVE_VALUE_STR,
             separator=AGENT_ANSWER_SEPARATOR,
         )
-    except LLMTimeoutError:
+    except (LLMTimeoutError, TimeoutError):
         refined_answer_quality = True
         logger.error("LLM Timeout Error - validate refined answer")
@@ -34,14 +34,16 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
+    AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
 )
+from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     QUERY_REWRITING_PROMPT,
 )
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()

@@ -69,7 +71,7 @@ def expand_queries(
     node_start_time = datetime.now()
     question = state.question

-    llm = graph_config.tooling.fast_llm
+    model = graph_config.tooling.fast_llm
    sub_question_id = state.sub_question_id
    if sub_question_id is None:
        level, question_num = 0, 0

@@ -88,10 +90,12 @@ def expand_queries(
     rewritten_queries = []

     try:
-        llm_response_list = dispatch_separated(
-            llm.stream(
+        llm_response_list = run_with_timeout(
+            AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
+            dispatch_separated,
+            model.stream(
                 prompt=msg,
-                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
+                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
             ),
             dispatch_subquery(level, question_num, writer),
         )

@@ -101,7 +105,7 @@ def expand_queries(
         rewritten_queries = llm_response.split("\n")
         log_result = f"Number of expanded queries: {len(rewritten_queries)}"

-    except LLMTimeoutError:
+    except (LLMTimeoutError, TimeoutError):
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
@@ -55,6 +55,7 @@ def rerank_documents(

     # Note that these are passed in values from the API and are overrides which are typically None
     rerank_settings = graph_config.inputs.search_request.rerank_settings
+    allow_agent_reranking = graph_config.behavior.allow_agent_reranking

     if rerank_settings is None:
         with get_session_context_manager() as db_session:

@@ -62,23 +63,31 @@ def rerank_documents(
             if not search_settings.disable_rerank_for_streaming:
                 rerank_settings = RerankingDetails.from_db_model(search_settings)

+    # Initial default: no reranking. Will be overwritten below if reranking is warranted
+    reranked_documents = verified_documents
+
     if should_rerank(rerank_settings) and len(verified_documents) > 0:
-        if len(verified_documents) > 1:
-            reranked_documents = rerank_sections(
-                query_str=question,
-                # if runnable, then rerank_settings is not None
-                rerank_settings=cast(RerankingDetails, rerank_settings),
-                sections_to_rerank=verified_documents,
-            )
+        if not allow_agent_reranking:
+            logger.info("Use of local rerank model without GPU, skipping reranking")
+            # No reranking, stay with verified_documents as default
+
+        else:
+            # Reranking is warranted, use the rerank_sections function
+            reranked_documents = rerank_sections(
+                query_str=question,
+                # if runnable, then rerank_settings is not None
+                rerank_settings=cast(RerankingDetails, rerank_settings),
+                sections_to_rerank=verified_documents,
+            )
-        else:
-            logger.warning(
-                f"{len(verified_documents)} verified document(s) found, skipping reranking"
-            )
-            reranked_documents = verified_documents
+            # No reranking, stay with verified_documents as default
     else:
         logger.warning("No reranking settings found, using unranked documents")
-        reranked_documents = verified_documents
+        # No reranking, stay with verified_documents as default

     if AGENT_RERANKING_STATS:
         fit_scores = get_fit_scores(verified_documents, reranked_documents)
     else:
@@ -25,13 +25,15 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
-from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
+from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
+from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     DOCUMENT_VERIFICATION_PROMPT,
 )
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()

@@ -86,8 +88,11 @@ def verify_documents(
     ]  # default is to treat document as relevant

     try:
-        response = fast_llm.invoke(
-            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
+        response = run_with_timeout(
+            AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION,
+            fast_llm.invoke,
+            prompt=msg,
+            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
         )

         assert isinstance(response.content, str)

@@ -96,7 +101,7 @@ def verify_documents(
     ):
         verified_documents = []

-    except LLMTimeoutError:
+    except (LLMTimeoutError, TimeoutError):
         # In this case, we decide to continue and don't raise an error, as
         # little harm in letting some docs through that are less relevant.
         logger.error("LLM Timeout Error - verify documents")
@@ -67,6 +67,7 @@ class GraphSearchConfig(BaseModel):
     # Whether to allow creation of refinement questions (and entity extraction, etc.)
     allow_refinement: bool = True
     skip_gen_ai_answer_generation: bool = False
+    allow_agent_reranking: bool = False


 class GraphConfig(BaseModel):
@@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
     write_custom_event("basic_response", packet, writer)


-def tool_call(
+def call_tool(
     state: ToolChoiceUpdate,
     config: RunnableConfig,
     writer: StreamWriter = lambda _: None,

@@ -25,7 +25,7 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
-def llm_tool_choice(
+def choose_tool(
     state: ToolChoiceState,
     config: RunnableConfig,
     writer: StreamWriter = lambda _: None,
@@ -43,8 +43,9 @@ from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
+    AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
 )
+from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
 from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
 from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from onyx.configs.constants import DEFAULT_PERSONA_ID

@@ -80,6 +81,7 @@ from onyx.tools.tool_implementations.search.search_tool import SearchResponseSum
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.utils import explicit_tool_calling_supported
 from onyx.utils.logger import setup_logger
+from onyx.utils.threadpool_concurrency import run_with_timeout

 logger = setup_logger()

@@ -395,11 +397,13 @@ def summarize_history(
     )

     try:
-        history_response = llm.invoke(
+        history_response = run_with_timeout(
+            AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION,
+            llm.invoke,
             history_context_prompt,
-            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
+            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
         )
-    except LLMTimeoutError:
+    except (LLMTimeoutError, TimeoutError):
         logger.error("LLM Timeout Error - summarize history")
         return (
             history  # this is what is done at this point anyway, so we default to this
@@ -10,6 +10,7 @@ from pydantic import BaseModel

 from onyx.auth.schemas import UserRole
 from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
+from shared_configs.configs import MULTI_TENANT


 _API_KEY_HEADER_NAME = "Authorization"

@@ -35,8 +36,7 @@ class ApiKeyDescriptor(BaseModel):


 def generate_api_key(tenant_id: str | None = None) -> str:
-    # For backwards compatibility, if no tenant_id, generate old style key
-    if not tenant_id:
+    if not MULTI_TENANT or not tenant_id:
         return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)

     encoded_tenant = quote(tenant_id)  # URL encode the tenant ID
@@ -2,6 +2,8 @@ import smtplib
 from datetime import datetime
 from email.mime.multipart import MIMEMultipart
 from email.mime.text import MIMEText
+from email.utils import formatdate
+from email.utils import make_msgid

 from onyx.configs.app_configs import EMAIL_CONFIGURED
 from onyx.configs.app_configs import EMAIL_FROM

@@ -10,8 +12,10 @@ from onyx.configs.app_configs import SMTP_PORT
 from onyx.configs.app_configs import SMTP_SERVER
 from onyx.configs.app_configs import SMTP_USER
 from onyx.configs.app_configs import WEB_DOMAIN
+from onyx.configs.constants import AuthType
+from onyx.configs.constants import TENANT_ID_COOKIE_NAME
 from onyx.db.models import User
+from shared_configs.configs import MULTI_TENANT

 HTML_EMAIL_TEMPLATE = """\
 <!DOCTYPE html>

@@ -149,8 +153,9 @@ def send_email(
     msg = MIMEMultipart("alternative")
     msg["Subject"] = subject
     msg["To"] = user_email
-    if mail_from:
-        msg["From"] = mail_from
+    msg["From"] = mail_from
+    msg["Date"] = formatdate(localtime=True)
+    msg["Message-ID"] = make_msgid(domain="onyx.app")

     part_text = MIMEText(text_body, "plain")
     part_html = MIMEText(html_body, "html")

@@ -172,7 +177,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
     subject = "Your Onyx Subscription Has Been Canceled"
     heading = "Subscription Canceled"
     message = (
-        "<p>We’re sorry to see you go.</p>"
+        "<p>We're sorry to see you go.</p>"
         "<p>Your subscription has been canceled and will end on your next billing date.</p>"
         "<p>If you change your mind, you can always come back!</p>"
     )

@@ -187,36 +192,64 @@ def send_subscription_cancellation_email(user_email: str) -> None:
     send_email(user_email, subject, html_content, text_content)


-def send_user_email_invite(user_email: str, current_user: User) -> None:
+def send_user_email_invite(
+    user_email: str, current_user: User, auth_type: AuthType
+) -> None:
     subject = "Invitation to Join Onyx Organization"
     heading = "You've Been Invited!"
-    message = (
-        f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
-        "<p>To join the organization, please click the button below to set a password "
-        "or login with Google and complete your registration.</p>"
-    )
+
+    # the exact action taken by the user, and thus the message, depends on the auth type
+    message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
+    if auth_type == AuthType.CLOUD:
+        message += (
+            "<p>To join the organization, please click the button below to set a password "
+            "or login with Google and complete your registration.</p>"
+        )
+    elif auth_type == AuthType.BASIC:
+        message += (
+            "<p>To join the organization, please click the button below to set a password "
+            "and complete your registration.</p>"
+        )
+    elif auth_type == AuthType.GOOGLE_OAUTH:
+        message += (
+            "<p>To join the organization, please click the button below to login with Google "
+            "and complete your registration.</p>"
+        )
+    elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML:
+        message += (
+            "<p>To join the organization, please click the button below to"
+            " complete your registration.</p>"
+        )
+    else:
+        raise ValueError(f"Invalid auth type: {auth_type}")

     cta_text = "Join Organization"
     cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
     html_content = build_html_email(heading, message, cta_text, cta_link)

+    # text content is the fallback for clients that don't support HTML
+    # not as critical, so not having special cases for each auth type
     text_content = (
         f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
         "To join the organization, please visit the following link:\n"
         f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
-        "You'll be asked to set a password or login with Google to complete your registration."
     )
+    if auth_type == AuthType.CLOUD:
+        text_content += "You'll be asked to set a password or login with Google to complete your registration."

     send_email(user_email, subject, html_content, text_content)


 def send_forgot_password_email(
     user_email: str,
     token: str,
-    tenant_id: str,
     mail_from: str = EMAIL_FROM,
+    tenant_id: str | None = None,
 ) -> None:
     # Builds a forgot password email with or without fancy HTML
     subject = "Onyx Forgot Password"
     link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
-    if tenant_id:
+    if MULTI_TENANT:
         link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
     message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
     html_content = build_html_email("Reset Your Password", message)
@@ -42,4 +42,5 @@ def fetch_no_auth_user(
         role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
         preferences=load_no_auth_user_preferences(store),
         is_anonymous_user=anonymous_user_enabled,
+        password_configured=False,
     )
@@ -1,5 +1,7 @@
 import json
+import random
 import secrets
+import string
 import uuid
 from collections.abc import AsyncGenerator
 from datetime import datetime

@@ -86,7 +88,6 @@ from onyx.db.auth import get_user_db
 from onyx.db.auth import SQLAlchemyUserAdminDB
 from onyx.db.engine import get_async_session
 from onyx.db.engine import get_async_session_with_tenant
-from onyx.db.engine import get_current_tenant_id
 from onyx.db.engine import get_session_with_tenant
 from onyx.db.models import AccessToken
 from onyx.db.models import OAuthAccount

@@ -94,6 +95,7 @@ from onyx.db.models import User
 from onyx.db.users import get_user_by_email
 from onyx.redis.redis_pool import get_async_redis_connection
 from onyx.redis.redis_pool import get_redis_client
+from onyx.server.utils import BasicAuthenticationError
 from onyx.utils.logger import setup_logger
 from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.telemetry import optional_telemetry

@@ -103,15 +105,11 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
 from shared_configs.configs import async_return_default_schema
 from shared_configs.configs import MULTI_TENANT
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
+from shared_configs.contextvars import get_current_tenant_id

 logger = setup_logger()


-class BasicAuthenticationError(HTTPException):
-    def __init__(self, detail: str):
-        super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
-
-
 def is_user_admin(user: User | None) -> bool:
     if AUTH_TYPE == AuthType.DISABLED:
         return True

@@ -143,6 +141,30 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
     return email or ""


+def generate_password() -> str:
+    lowercase_letters = string.ascii_lowercase
+    uppercase_letters = string.ascii_uppercase
+    digits = string.digits
+    special_characters = string.punctuation
+
+    # Ensure at least one of each required character type
+    password = [
+        secrets.choice(uppercase_letters),
+        secrets.choice(digits),
+        secrets.choice(special_characters),
+    ]
+
+    # Fill the rest with a mix of characters
+    remaining_length = 12 - len(password)
+    all_characters = lowercase_letters + uppercase_letters + digits + special_characters
+    password.extend(secrets.choice(all_characters) for _ in range(remaining_length))
+
+    # Shuffle the password to randomize the position of the required characters
+    random.shuffle(password)
+
+    return "".join(password)
+
+
 def user_needs_to_be_verified() -> bool:
     if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
         return REQUIRE_EMAIL_VERIFICATION
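
A quick, purely illustrative sanity check of the policy the generate_password helper above enforces (12 characters, with at least one uppercase letter, one digit, and one punctuation character):

```python
import string

pw = generate_password()
assert len(pw) == 12
assert any(c in string.ascii_uppercase for c in pw)
assert any(c in string.digits for c in pw)
assert any(c in string.punctuation for c in pw)
```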
@@ -192,8 +214,8 @@ def verify_email_is_invited(email: str) -> None:
     raise PermissionError("User not on allowed user whitelist")


-def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
-    with get_session_with_tenant(tenant_id) as db_session:
+def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
+    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
         if not get_user_by_email(email, db_session):
             verify_email_is_invited(email)

@@ -389,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
             "refresh_token": refresh_token,
         }

-        user: User
+        user: User | None = None

         try:
             # Attempt to get user by OAuth account

@@ -398,15 +420,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
         except exceptions.UserNotExists:
             try:
                 # Attempt to get user by email
-                user = await self.get_by_email(account_email)
+                user = await self.user_db.get_by_email(account_email)
                 if not associate_by_email:
                     raise exceptions.UserAlreadyExists()

-                user = await self.user_db.add_oauth_account(
-                    user, oauth_account_dict
-                )
+                # Make sure user is not None before adding OAuth account
+                if user is not None:
+                    user = await self.user_db.add_oauth_account(
+                        user, oauth_account_dict
+                    )
+                else:
+                    # This shouldn't happen since get_by_email would raise UserNotExists
+                    # but adding as a safeguard
+                    raise exceptions.UserNotExists()

             # If user not found by OAuth account or email, create a new user
             except exceptions.UserNotExists:
                 password = self.password_helper.generate()
                 user_dict = {

@@ -417,26 +444,36 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):

                 user = await self.user_db.create(user_dict)

-                # Explicitly set the Postgres schema for this session to ensure
-                # OAuth account creation happens in the correct tenant schema
-
-                # Add OAuth account
-                await self.user_db.add_oauth_account(user, oauth_account_dict)
-                await self.on_after_register(user, request)
+                # Add OAuth account only if user creation was successful
+                if user is not None:
+                    await self.user_db.add_oauth_account(user, oauth_account_dict)
+                    await self.on_after_register(user, request)
+                else:
+                    raise HTTPException(
+                        status_code=500, detail="Failed to create user account"
+                    )

         else:
-            for existing_oauth_account in user.oauth_accounts:
-                if (
-                    existing_oauth_account.account_id == account_id
-                    and existing_oauth_account.oauth_name == oauth_name
-                ):
-                    user = await self.user_db.update_oauth_account(
-                        user,
-                        # NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
-                        # but the type checker doesn't know that :(
-                        existing_oauth_account,  # type: ignore
-                        oauth_account_dict,
-                    )
+            # User exists, update OAuth account if needed
+            if user is not None:  # Add explicit check
+                for existing_oauth_account in user.oauth_accounts:
+                    if (
+                        existing_oauth_account.account_id == account_id
+                        and existing_oauth_account.oauth_name == oauth_name
+                    ):
+                        user = await self.user_db.update_oauth_account(
+                            user,
+                            # NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
+                            # but the type checker doesn't know that :(
+                            existing_oauth_account,  # type: ignore
+                            oauth_account_dict,
+                        )
+
+        # Ensure user is not None before proceeding
+        if user is None:
+            raise HTTPException(
+                status_code=500, detail="Failed to authenticate or create user"
+            )

         # NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
         # re-authenticate that frequently, so by default this is disabled

@@ -531,7 +568,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
             async_return_default_schema,
         )(email=user.email)

-        send_forgot_password_email(user.email, token, tenant_id=tenant_id)
+        send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)

     async def on_after_request_verify(
         self, user: User, token: str, request: Optional[Request] = None

@@ -595,6 +632,39 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):

         return user

+    async def reset_password_as_admin(self, user_id: uuid.UUID) -> str:
+        """Admin-only. Generate a random password for a user and return it."""
+        user = await self.get(user_id)
+        new_password = generate_password()
+        await self._update(user, {"password": new_password})
+        return new_password
+
+    async def change_password_if_old_matches(
+        self, user: User, old_password: str, new_password: str
+    ) -> None:
+        """
+        For normal users to change password if they know the old one.
+        Raises 400 if old password doesn't match.
+        """
+        verified, updated_password_hash = self.password_helper.verify_and_update(
+            old_password, user.hashed_password
+        )
+        if not verified:
+            # Raise some HTTPException (or your custom exception) if old password is invalid:
+            from fastapi import HTTPException, status
+
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Invalid current password",
+            )
+
+        # If the hash was upgraded behind the scenes, we can keep it before setting the new password:
+        if updated_password_hash:
+            user.hashed_password = updated_password_hash
+
+        # Now apply and validate the new password
+        await self._update(user, {"password": new_password})
+

 async def get_user_manager(
     user_db: SQLAlchemyUserDatabase = Depends(get_user_db),

@@ -819,8 +889,9 @@ async def current_limited_user(

 async def current_chat_accesssible_user(
     user: User | None = Depends(optional_user),
-    tenant_id: str | None = Depends(get_current_tenant_id),
 ) -> User | None:
+    tenant_id = get_current_tenant_id()
+
     return await double_check_user(
         user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
     )
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
 import multiprocessing
 import time
 from typing import Any
+from typing import cast

 import sentry_sdk
 from celery import Task

@@ -33,6 +34,7 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
 from onyx.redis.redis_connector_prune import RedisConnectorPrune
 from onyx.redis.redis_document_set import RedisDocumentSet
 from onyx.redis.redis_pool import get_redis_client
+from onyx.redis.redis_pool import get_shared_redis_client
 from onyx.redis.redis_usergroup import RedisUserGroup
 from onyx.utils.logger import ColoredFormatter
 from onyx.utils.logger import PlainFormatter
@@ -58,13 +60,35 @@ else:
     logger.debug("Sentry DSN not provided, skipping Sentry initialization")


+class TenantAwareTask(Task):
+    """A custom base Task that sets tenant_id in a contextvar before running."""
+
+    abstract = True  # So Celery knows not to register this as a real task.
+
+    def __call__(self, *args: Any, **kwargs: Any) -> Any:
+        # Grab tenant_id from the kwargs, or fall back to the default if missing.
+        tenant_id = kwargs.get("tenant_id", None) or POSTGRES_DEFAULT_SCHEMA
+
+        # Set the context var
+        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+
+        # Actually run the task now
+        try:
+            return super().__call__(*args, **kwargs)
+        finally:
+            # Clear or reset after the task runs
+            # so it does not leak into any subsequent tasks on the same worker process
+            CURRENT_TENANT_ID_CONTEXTVAR.set(None)
+
+
 @task_prerun.connect
 def on_task_prerun(
     sender: Any | None = None,
     task_id: str | None = None,
     task: Task | None = None,
     args: tuple[Any, ...] | None = None,
     kwargs: dict[str, Any] | None = None,
-    **kwds: Any,
+    **other_kwargs: Any,
 ) -> None:
     pass
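The `TenantAwareTask` base above replaces the old `set_tenant_id` signal handler (removed in a later hunk). A minimal self-contained sketch of the same contextvar pattern — the `"public"` default, broker URL, and task body are illustrative, not Onyx's:

```python
import contextvars
from typing import Any

from celery import Celery, Task

CURRENT_TENANT: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_tenant", default=None
)

app = Celery("sketch", broker="memory://")


class TenantAwareTask(Task):
    abstract = True  # not registered as a runnable task itself

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        token = CURRENT_TENANT.set(kwargs.get("tenant_id") or "public")
        try:
            return super().__call__(*args, **kwargs)
        finally:
            CURRENT_TENANT.reset(token)  # never leak into the next task


@app.task(base=TenantAwareTask)
def whoami(*, tenant_id: str) -> str:
    # any code the task calls can read the contextvar instead of threading
    # tenant_id through every signature
    return CURRENT_TENANT.get() or "unknown"
```

One design note: `reset(token)` restores whatever value was previously in scope, while the hunk above sets `None` explicitly — both achieve the same isolation on Celery's one-task-at-a-time prefork workers.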
@@ -108,9 +132,9 @@ def on_task_postrun(
     # Get tenant_id directly from kwargs - each celery task has a tenant_id kwarg
     if not kwargs:
         logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
-        tenant_id = None
+        tenant_id = POSTGRES_DEFAULT_SCHEMA
     else:
-        tenant_id = kwargs.get("tenant_id")
+        tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))

     task_logger.debug(
         f"Task {task.name} (ID: {task_id}) completed with state: {state} "
@@ -201,7 +225,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
     Will raise WorkerShutdown to kill the celery worker if the timeout
     is reached."""

-    r = get_redis_client(tenant_id=None)
+    r = get_shared_redis_client()

     WAIT_INTERVAL = 5
     WAIT_LIMIT = 60
@@ -287,7 +311,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
     # Set up variables for waiting on primary worker
     WAIT_INTERVAL = 5
     WAIT_LIMIT = 60
-    r = get_redis_client(tenant_id=None)
+    r = get_shared_redis_client()
     time_start = time.monotonic()

     logger.info("Waiting for primary worker to be ready...")
@@ -439,24 +463,6 @@ class TenantContextFilter(logging.Filter):
         return True


-@task_prerun.connect
-def set_tenant_id(
-    sender: Any | None = None,
-    task_id: str | None = None,
-    task: Task | None = None,
-    args: tuple[Any, ...] | None = None,
-    kwargs: dict[str, Any] | None = None,
-    **other_kwargs: Any,
-) -> None:
-    """Signal handler to set tenant ID in context var before task starts."""
-    tenant_id = (
-        kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
-        if kwargs
-        else POSTGRES_DEFAULT_SCHEMA
-    )
-    CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
-
-
 @task_postrun.connect
 def reset_tenant_id(
     sender: Any | None = None,
@@ -132,6 +132,7 @@ class DynamicTenantScheduler(PersistentScheduler):
                     f"Adding options to task {tenant_task_name}: {options}"
                 )
+                tenant_task["options"] = options

             new_schedule[tenant_task_name] = tenant_task

         return new_schedule

@@ -256,3 +257,4 @@ def on_setup_logging(


 celery_app.conf.beat_scheduler = DynamicTenantScheduler
+celery_app.conf.task_default_base = app_base.TenantAwareTask
@@ -20,6 +20,7 @@ logger = setup_logger()

 celery_app = Celery(__name__)
 celery_app.config_from_object("onyx.background.celery.configs.heavy")
+celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]


 @signals.task_prerun.connect

@@ -21,6 +21,7 @@ logger = setup_logger()

 celery_app = Celery(__name__)
 celery_app.config_from_object("onyx.background.celery.configs.indexing")
+celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]


 @signals.task_prerun.connect

@@ -23,6 +23,7 @@ logger = setup_logger()

 celery_app = Celery(__name__)
 celery_app.config_from_object("onyx.background.celery.configs.light")
+celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]


 @signals.task_prerun.connect

@@ -20,6 +20,7 @@ logger = setup_logger()

 celery_app = Celery(__name__)
 celery_app.config_from_object("onyx.background.celery.configs.monitoring")
+celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]


 @signals.task_prerun.connect
@@ -24,7 +24,7 @@ from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
 from onyx.configs.constants import OnyxRedisConstants
 from onyx.configs.constants import OnyxRedisLocks
 from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
-from onyx.db.engine import get_session_with_default_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.engine import SqlEngine
 from onyx.db.index_attempt import get_index_attempt
 from onyx.db.index_attempt import mark_attempt_canceled

@@ -38,7 +38,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
 from onyx.redis.redis_connector_prune import RedisConnectorPrune
 from onyx.redis.redis_connector_stop import RedisConnectorStop
 from onyx.redis.redis_document_set import RedisDocumentSet
-from onyx.redis.redis_pool import get_redis_client
+from onyx.redis.redis_pool import get_shared_redis_client
 from onyx.redis.redis_usergroup import RedisUserGroup
 from onyx.utils.logger import setup_logger
 from shared_configs.configs import MULTI_TENANT

@@ -47,6 +47,7 @@ logger = setup_logger()

 celery_app = Celery(__name__)
 celery_app.config_from_object("onyx.background.celery.configs.primary")
+celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]


 @signals.task_prerun.connect
@@ -101,7 +102,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:

     # This is singleton work that should be done on startup exactly once
     # by the primary worker. This is unnecessary in the multi tenant scenario
-    r = get_redis_client(tenant_id=None)
+    r = get_shared_redis_client()

     # Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
     info: dict[str, Any] = cast(dict, r.info("replication"))

@@ -158,7 +159,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
     RedisConnectorExternalGroupSync.reset_all(r)

     # mark orphaned index attempts as failed
-    with get_session_with_default_tenant() as db_session:
+    with get_session_with_current_tenant() as db_session:
         unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
         for attempt_id in unfenced_attempt_ids:
             attempt = get_index_attempt(db_session, attempt_id)

@@ -234,7 +235,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):

         lock: RedisLock = worker.primary_worker_lock

-        r = get_redis_client(tenant_id=None)
+        r = get_shared_redis_client()

         if lock.owned():
             task_logger.debug("Reacquiring primary worker lock.")
@@ -92,7 +92,8 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:


 def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
-    """This is a redis specific way to build a list of tasks in a queue.
+    """This is a redis specific way to build a list of tasks in a queue and return them
+    as a set.

     This helps us read the queue once and then efficiently look for missing tasks
     in the queue.
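For reference, one way such a queue scan can work with Celery's Redis broker — a sketch, not Onyx's implementation: pending messages live in a Redis list named after the queue, and each entry is a JSON envelope whose headers carry the task id.

```python
import json

from redis import Redis


def queued_task_ids(queue: str, r: Redis) -> set[str]:
    task_ids: set[str] = set()
    # LRANGE reads the whole pending list in one round trip, so callers can
    # diff their own tasksets against it without per-task lookups.
    for raw in r.lrange(queue, 0, -1):
        try:
            envelope = json.loads(raw)
            task_ids.add(envelope["headers"]["id"])
        except (ValueError, KeyError, TypeError):
            continue  # skip anything that isn't a well-formed task message
    return task_ids
```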
@@ -34,7 +34,7 @@ def _get_deletion_status(
     connector_id: int,
     credential_id: int,
     db_session: Session,
-    tenant_id: str | None = None,
+    tenant_id: str,
 ) -> TaskQueueState | None:
     """We no longer store TaskQueueState in the DB for a deletion attempt.
     This function populates TaskQueueState by just checking redis.

@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
     connector_id: int,
     credential_id: int,
     db_session: Session,
-    tenant_id: str | None = None,
+    tenant_id: str,
 ) -> DeletionAttemptSnapshot | None:
     deletion_task = _get_deletion_status(
         connector_id, credential_id, db_session, tenant_id
@@ -8,16 +8,21 @@ from celery import Celery
 from celery import shared_task
 from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
+from pydantic import ValidationError
 from redis import Redis
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session

 from onyx.background.celery.apps.app_base import task_logger
+from onyx.background.celery.celery_redis import celery_get_queue_length
+from onyx.background.celery.celery_redis import celery_get_queued_task_ids
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisConstants
 from onyx.configs.constants import OnyxRedisLocks
 from onyx.configs.constants import OnyxRedisSignals
 from onyx.db.connector import fetch_connector_by_id
 from onyx.db.connector_credential_pair import add_deletion_failure_message
 from onyx.db.connector_credential_pair import (

@@ -27,7 +32,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
 from onyx.db.connector_credential_pair import get_connector_credential_pairs
 from onyx.db.document import get_document_ids_for_connector_credential_pair
 from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.enums import ConnectorCredentialPairStatus
 from onyx.db.enums import SyncStatus
 from onyx.db.enums import SyncType
@@ -52,6 +57,51 @@ class TaskDependencyError(RuntimeError):
     with connector deletion."""


+def revoke_tasks_blocking_deletion(
+    redis_connector: RedisConnector, db_session: Session, app: Celery
+) -> None:
+    search_settings_list = get_all_search_settings(db_session)
+    for search_settings in search_settings_list:
+        redis_connector_index = redis_connector.new_index(search_settings.id)
+        try:
+            index_payload = redis_connector_index.payload
+            if index_payload and index_payload.celery_task_id:
+                app.control.revoke(index_payload.celery_task_id)
+                task_logger.info(
+                    f"Revoked indexing task {index_payload.celery_task_id}."
+                )
+        except Exception:
+            task_logger.exception("Exception while revoking indexing task")
+
+    try:
+        permissions_sync_payload = redis_connector.permissions.payload
+        if permissions_sync_payload and permissions_sync_payload.celery_task_id:
+            app.control.revoke(permissions_sync_payload.celery_task_id)
+            task_logger.info(
+                f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
+            )
+    except Exception:
+        task_logger.exception("Exception while revoking permissions sync task")
+
+    try:
+        prune_payload = redis_connector.prune.payload
+        if prune_payload and prune_payload.celery_task_id:
+            app.control.revoke(prune_payload.celery_task_id)
+            task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
+    except Exception:
+        task_logger.exception("Exception while revoking pruning task")
+
+    try:
+        external_group_sync_payload = redis_connector.external_group_sync.payload
+        if external_group_sync_payload and external_group_sync_payload.celery_task_id:
+            app.control.revoke(external_group_sync_payload.celery_task_id)
+            task_logger.info(
+                f"Revoked external group sync task {external_group_sync_payload.celery_task_id}."
+            )
+    except Exception:
+        task_logger.exception("Exception while revoking external group sync task")
+
+
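The helper above repeats one guarded revoke per payload type; the generic Celery call underneath is simply `app.control.revoke`. A sketch of the shared shape (the payload objects stand in for Onyx's Redis payload classes):

```python
from celery import Celery


def revoke_if_tracked(app: Celery, celery_task_id: str | None, label: str) -> None:
    # control.revoke broadcasts to workers and is a no-op for unknown ids,
    # so the guard mainly keeps the logs honest.
    if not celery_task_id:
        return
    app.control.revoke(celery_task_id)
    print(f"Revoked {label} task {celery_task_id}.")
```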
 @shared_task(
     name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
     ignore_result=True,
@@ -59,32 +109,46 @@ class TaskDependencyError(RuntimeError):
     trail=False,
     bind=True,
 )
-def check_for_connector_deletion_task(
-    self: Task, *, tenant_id: str | None
-) -> bool | None:
-    r = get_redis_client(tenant_id=tenant_id)
-    r_replica = get_redis_replica_client(tenant_id=tenant_id)
+def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
+    r = get_redis_client()
+    r_replica = get_redis_replica_client()
     r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

     lock_beat: RedisLock = r.lock(
         OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
         timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
     )

-    # these tasks should never overlap
+    # Prevent this task from overlapping with itself
     if not lock_beat.acquire(blocking=False):
         return None

     try:
+        # we want to run this less frequently than the overall task
+        lock_beat.reacquire()
+        if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
+            # clear fences that don't have associated celery tasks in progress
+            try:
+                validate_connector_deletion_fences(
+                    tenant_id, r, r_replica, r_celery, lock_beat
+                )
+            except Exception:
+                task_logger.exception(
+                    "Exception while validating connector deletion fences"
+                )
+
+            r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
+
         # collect cc_pair_ids
         cc_pair_ids: list[int] = []
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             cc_pairs = get_connector_credential_pairs(db_session)
             for cc_pair in cc_pairs:
                 cc_pair_ids.append(cc_pair.id)

         # try running cleanup on the cc_pair_ids
         for cc_pair_id in cc_pair_ids:
-            with get_session_with_tenant(tenant_id) as db_session:
+            with get_session_with_current_tenant() as db_session:
                 redis_connector = RedisConnector(tenant_id, cc_pair_id)
                 try:
                     try_generate_document_cc_pair_cleanup_tasks(

@@ -92,9 +156,38 @@ def check_for_connector_deletion_task(
                     )
                 except TaskDependencyError as e:
-                    # this means we wanted to start deleting but dependent tasks were running
-                    # Leave a stop signal to clear indexing and pruning tasks more quickly
+                    # on the first error, we set a stop signal and revoke the dependent tasks
+                    # on subsequent errors, we hard reset blocking fences after our specified timeout
+                    # is exceeded
                     task_logger.info(str(e))
-                    redis_connector.stop.set_fence(True)
+
+                    if not redis_connector.stop.fenced:
+                        # one time revoke of celery tasks
+                        task_logger.info("Revoking any tasks blocking deletion.")
+                        revoke_tasks_blocking_deletion(
+                            redis_connector, db_session, self.app
+                        )
+                        redis_connector.stop.set_fence(True)
+                        redis_connector.stop.set_timeout()
+                    else:
+                        # stop signal already set
+                        if redis_connector.stop.timed_out:
+                            # waiting too long, just reset blocking fences
+                            task_logger.info(
+                                "Timed out waiting for tasks blocking deletion. Resetting blocking fences."
+                            )
+                            search_settings_list = get_all_search_settings(db_session)
+                            for search_settings in search_settings_list:
+                                redis_connector_index = redis_connector.new_index(
+                                    search_settings.id
+                                )
+                                redis_connector_index.reset()
+                            redis_connector.prune.reset()
+                            redis_connector.permissions.reset()
+                            redis_connector.external_group_sync.reset()
+                        else:
+                            # just wait
+                            pass
                 else:
                     # clear the stop signal if it exists ... no longer needed
                     redis_connector.stop.set_fence(False)
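All of these beat tasks follow the same Redis lock discipline: a non-blocking acquire so overlapping runs simply skip, periodic `reacquire()` calls to extend the TTL between long phases, and a release guarded by `owned()`. A condensed sketch of that discipline (lock name, timeout, and work phases are placeholders):

```python
from redis import Redis
from redis.lock import Lock as RedisLock


def run_beat_task(r: Redis, lock_name: str, timeout: int = 120) -> bool:
    lock: RedisLock = r.lock(lock_name, timeout=timeout)
    if not lock.acquire(blocking=False):
        return False  # another run holds the lock; skip this tick entirely

    try:
        lock.reacquire()  # reset the TTL before each long-running phase
        # ... phase 1: validate fences ...
        lock.reacquire()
        # ... phase 2: generate cleanup tasks ...
        return True
    finally:
        if lock.owned():  # the TTL may have expired mid-run
            lock.release()
```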
@@ -129,7 +222,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
     cc_pair_id: int,
     db_session: Session,
     lock_beat: RedisLock,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> int | None:
     """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
     Note that syncing can still be required even if the number of sync tasks generated is zero.

@@ -169,6 +262,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
         return None

     # set a basic fence to start
+    redis_connector.delete.set_active()
     fence_payload = RedisConnectorDeletePayload(
         num_tasks=None,
         submitted=datetime.now(timezone.utc),

@@ -249,7 +343,7 @@ def try_generate_document_cc_pair_cleanup_tasks(


 def monitor_connector_deletion_taskset(
-    tenant_id: str | None, key_bytes: bytes, r: Redis
+    tenant_id: str, key_bytes: bytes, r: Redis
 ) -> None:
     fence_key = key_bytes.decode("utf-8")
     cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)

@@ -277,7 +371,7 @@ def monitor_connector_deletion_taskset(
             f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
         )
     if remaining > 0:
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             update_sync_record_status(
                 db_session=db_session,
                 entity_id=cc_pair_id,

@@ -287,7 +381,7 @@ def monitor_connector_deletion_taskset(
             )
         return

-    with get_session_with_tenant(tenant_id) as db_session:
+    with get_session_with_current_tenant() as db_session:
         cc_pair = get_connector_credential_pair_from_id(
             db_session=db_session,
             cc_pair_id=cc_pair_id,

@@ -401,3 +495,171 @@ def monitor_connector_deletion_taskset(
         )

     redis_connector.delete.reset()
+
+
+def validate_connector_deletion_fences(
+    tenant_id: str,
+    r: Redis,
+    r_replica: Redis,
+    r_celery: Redis,
+    lock_beat: RedisLock,
+) -> None:
+    # building lookup table can be expensive, so we won't bother
+    # validating until the queue is small
+    CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
+
+    queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
+    if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
+        return
+
+    queued_upsert_tasks = celery_get_queued_task_ids(
+        OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
+    )
+
+    # validate all existing connector deletion jobs
+    lock_beat.reacquire()
+    keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
+    for key in keys:
+        key_bytes = cast(bytes, key)
+        key_str = key_bytes.decode("utf-8")
+        if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
+            continue
+
+        validate_connector_deletion_fence(
+            tenant_id,
+            key_bytes,
+            queued_upsert_tasks,
+            r,
+        )
+
+        lock_beat.reacquire()
+
+    return
+
+
+def validate_connector_deletion_fence(
+    tenant_id: str,
+    key_bytes: bytes,
+    queued_tasks: set[str],
+    r: Redis,
+) -> None:
+    """Checks for the error condition where a deletion fence is set but the associated celery tasks don't exist.
+    This can happen if the worker hard crashes or is terminated.
+    Being in this bad state means the fence will never clear without help, so this function
+    gives the help.
+
+    How this works:
+    1. This function renews the active signal with a 5 minute TTL under the following conditions
+    1.1. When the task is seen in the redis queue
+    1.2. When the task is seen in the reserved / prefetched list
+
+    2. Externally, the active signal is renewed when:
+    2.1. The fence is created
+    2.2. The indexing watchdog checks the spawned task.
+
+    3. The TTL allows us to get through the transitions on fence startup
+    and when the task starts executing.
+
+    More TTL clarification: it is seemingly impossible to exactly query Celery for
+    whether a task is in the queue or currently executing.
+    1. An unknown task id is always returned as state PENDING.
+    2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
+    and the time it actually starts on the worker.
+
+    queued_tasks: the task ids currently queued in the connector deletion queue
+    """
+    # if the fence doesn't exist, there's nothing to do
+    fence_key = key_bytes.decode("utf-8")
+    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
+    if cc_pair_id_str is None:
+        task_logger.warning(
+            f"validate_connector_deletion_fence - could not parse id from {fence_key}"
+        )
+        return
+
+    cc_pair_id = int(cc_pair_id_str)
+    # parse out metadata and initialize the helper class with it
+    redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
+
+    # check to see if the fence/payload exists
+    if not redis_connector.delete.fenced:
+        return
+
+    # in the cloud, the payload format may have changed ...
+    # it's a little sloppy, but just reset the fence for now if that happens
+    # TODO: add intentional cleanup/abort logic
+    try:
+        payload = redis_connector.delete.payload
+    except ValidationError:
+        task_logger.exception(
+            "validate_connector_deletion_fence - "
+            "Resetting fence because fence schema is out of date: "
+            f"cc_pair={cc_pair_id} "
+            f"fence={fence_key}"
+        )
+
+        redis_connector.delete.reset()
+        return
+
+    if not payload:
+        return
+
+    # OK, there's actually something for us to validate
+
+    # look up every task in the current taskset in the celery queue
+    # every entry in the taskset should have an associated entry in the celery task queue
+    # because we get the celery tasks first, the entries in our own taskset
+    # should be roughly a subset of the tasks in celery
+
+    # this check isn't very exact, but should be sufficient over a period of time
+    # A single successful check over some number of attempts is sufficient.
+
+    # TODO: if the number of tasks in celery is much lower than the taskset length
+    # we might be able to shortcut the lookup since by definition some of the tasks
+    # must not exist in celery.
+
+    tasks_scanned = 0
+    tasks_not_in_celery = 0  # a non-zero number after completing our check is bad
+
+    for member in r.sscan_iter(redis_connector.delete.taskset_key):
+        tasks_scanned += 1
+
+        member_bytes = cast(bytes, member)
+        member_str = member_bytes.decode("utf-8")
+        if member_str in queued_tasks:
+            continue
+
+        tasks_not_in_celery += 1
+
+    task_logger.info(
+        "validate_connector_deletion_fence task check: "
+        f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
+    )
+
+    # we're active if there are still tasks to run and those tasks all exist in celery
+    if tasks_scanned > 0 and tasks_not_in_celery == 0:
+        redis_connector.delete.set_active()
+        return
+
+    # we may want to enable this check if using the active task list somehow isn't good enough
+    # if redis_connector_index.generator_locked():
+    #     logger.info(f"{payload.celery_task_id} is currently executing.")
+
+    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
+    # but they still might be there due to gaps in our ability to check states during transitions
+    # Checking the active signal safeguards us against these transition periods
+    # (which has a duration that allows us to bridge those gaps)
+    if redis_connector.delete.active():
+        return
+
+    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
+    task_logger.warning(
+        "validate_connector_deletion_fence - "
+        "Resetting fence because no associated celery tasks were found: "
+        f"cc_pair={cc_pair_id} "
+        f"fence={fence_key}"
+    )
+
+    redis_connector.delete.reset()
+    return
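The fence validator above leans on a TTL'd "active" key as a heartbeat. A sketch of that primitive in isolation — key name and TTL are illustrative; Onyx wraps this behind `redis_connector.delete.set_active()` / `.active()`:

```python
from redis import Redis

ACTIVE_TTL_SECONDS = 300  # long enough to bridge queued -> prefetched -> running


def renew_active_signal(r: Redis, key: str) -> None:
    # Every observer that can still see the work (fence creation, queue scans,
    # the watchdog) renews the key; if every observer stops seeing it, the key
    # expires and validators treat the fence as orphaned.
    r.set(key, 1, ex=ACTIVE_TTL_SECONDS)


def is_active(r: Redis, key: str) -> bool:
    return bool(r.exists(key))
```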
@@ -30,6 +30,7 @@ from onyx.background.celery.celery_redis import celery_find_task
 from onyx.background.celery.celery_redis import celery_get_queue_length
 from onyx.background.celery.celery_redis import celery_get_queued_task_ids
 from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
+from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT

@@ -42,10 +43,12 @@ from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisConstants
 from onyx.configs.constants import OnyxRedisLocks
 from onyx.configs.constants import OnyxRedisSignals
+from onyx.connectors.factory import validate_ccpair_for_user
 from onyx.db.connector import mark_cc_pair_as_permissions_synced
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
+from onyx.db.connector_credential_pair import update_connector_credential_pair
 from onyx.db.document import upsert_document_by_connector_credential_pair
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.enums import AccessType
 from onyx.db.enums import ConnectorCredentialPairStatus
 from onyx.db.enums import SyncStatus

@@ -63,6 +66,7 @@ from onyx.redis.redis_pool import get_redis_replica_client
 from onyx.redis.redis_pool import redis_lock_dump
 from onyx.server.utils import make_short_id
 from onyx.utils.logger import doc_permission_sync_ctx
+from onyx.utils.logger import format_error_for_logging
 from onyx.utils.logger import LoggerContextVars
 from onyx.utils.logger import setup_logger
@@ -119,13 +123,13 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
     soft_time_limit=JOB_TIMEOUT,
     bind=True,
 )
-def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
+def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None:
     # TODO(rkuo): merge into check function after lookup table for fences is added

     # we need to use celery's redis client to access its redis data
     # (which lives on a different db number)
-    r = get_redis_client(tenant_id=tenant_id)
-    r_replica = get_redis_replica_client(tenant_id=tenant_id)
+    r = get_redis_client()
+    r_replica = get_redis_replica_client()
     r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

     lock_beat: RedisLock = r.lock(

@@ -140,7 +144,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
     try:
         # get all cc pairs that need to be synced
         cc_pair_ids_to_sync: list[int] = []
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             cc_pairs = get_all_auto_sync_cc_pairs(db_session)

             for cc_pair in cc_pairs:

@@ -189,16 +193,23 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool

             key_str = key_bytes.decode("utf-8")
             if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
-                with get_session_with_tenant(tenant_id) as db_session:
+                with get_session_with_current_tenant() as db_session:
                     monitor_ccpair_permissions_taskset(
                         tenant_id, key_bytes, r, db_session
                     )
+
+        task_logger.info(f"check_for_doc_permissions_sync finished: tenant={tenant_id}")
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
         )
-    except Exception:
-        task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
+    except Exception as e:
+        error_msg = format_error_for_logging(e)
+        task_logger.warning(
+            f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id} {error_msg}"
+        )
+        task_logger.exception(
+            f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id}"
+        )
     finally:
         if lock_beat.owned():
             lock_beat.release()
@@ -210,7 +221,7 @@ def try_creating_permissions_sync_task(
     app: Celery,
     cc_pair_id: int,
     r: Redis,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> str | None:
     """Returns a randomized payload id on success.
     Returns None if no syncing is required."""

@@ -247,7 +258,7 @@ def try_creating_permissions_sync_task(
     # create before setting fence to avoid race condition where the monitoring
     # task updates the sync record before it is created
     try:
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             insert_sync_record(
                 db_session=db_session,
                 entity_id=cc_pair_id,

@@ -282,13 +293,19 @@ def try_creating_permissions_sync_task(
         redis_connector.permissions.set_fence(payload)

         payload_id = payload.id
-    except Exception:
-        task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
+    except Exception as e:
+        error_msg = format_error_for_logging(e)
+        task_logger.warning(
+            f"Unexpected try_creating_permissions_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
+        )
         return None
     finally:
         if lock.owned():
             lock.release()

+    task_logger.info(
+        f"try_creating_permissions_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
+    )
     return payload_id
@@ -303,7 +320,7 @@ def try_creating_permissions_sync_task(
 def connector_permission_sync_generator_task(
     self: Task,
     cc_pair_id: int,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> None:
     """
     Permission sync task that handles document permission syncing for a given connector credential pair

@@ -321,7 +338,7 @@ def connector_permission_sync_generator_task(

     redis_connector = RedisConnector(tenant_id, cc_pair_id)

-    r = get_redis_client(tenant_id=tenant_id)
+    r = get_redis_client()

     # this wait is needed to avoid a race condition where
     # the primary worker sends the task and it is immediately executed

@@ -378,7 +395,7 @@ def connector_permission_sync_generator_task(
         return None

     try:
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             cc_pair = get_connector_credential_pair_from_id(
                 db_session=db_session,
                 cc_pair_id=cc_pair_id,

@@ -388,6 +405,29 @@ def connector_permission_sync_generator_task(
                     f"No connector credential pair found for id: {cc_pair_id}"
                 )

+            try:
+                created = validate_ccpair_for_user(
+                    cc_pair.connector.id,
+                    cc_pair.credential.id,
+                    db_session,
+                    enforce_creation=False,
+                )
+                if not created:
+                    task_logger.warning(
+                        f"Unable to create connector credential pair for id: {cc_pair_id}"
+                    )
+            except Exception:
+                task_logger.exception(
+                    f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
+                )
+                update_connector_credential_pair(
+                    db_session=db_session,
+                    connector_id=cc_pair.connector.id,
+                    credential_id=cc_pair.credential.id,
+                    status=ConnectorCredentialPairStatus.INVALID,
+                )
+                raise
+
             source_type = cc_pair.connector.source

             doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)

@@ -439,6 +479,10 @@ def connector_permission_sync_generator_task(
         redis_connector.permissions.generator_complete = tasks_generated

     except Exception as e:
+        error_msg = format_error_for_logging(e)
+        task_logger.warning(
+            f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
+        )
         task_logger.exception(
             f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
         )
@@ -465,7 +509,7 @@
 )
 def update_external_document_permissions_task(
     self: Task,
-    tenant_id: str | None,
+    tenant_id: str,
     serialized_doc_external_access: dict,
     source_string: str,
     connector_id: int,

@@ -473,6 +517,8 @@
 ) -> bool:
     start = time.monotonic()

+    completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
+
     document_external_access = DocExternalAccess.from_dict(
         serialized_doc_external_access
     )

@@ -480,7 +526,8 @@
     external_access = document_external_access.external_access

     try:
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             # Add the users to the DB if they don't exist
             batch_add_ext_perm_user_if_not_exists(
                 db_session=db_session,
                 emails=list(external_access.external_user_emails),

@@ -511,18 +558,33 @@
             f"elapsed={elapsed:.2f}"
         )

-    except Exception:
+        completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
+    except Exception as e:
+        error_msg = format_error_for_logging(e)
+        task_logger.warning(
+            f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
+        )
         task_logger.exception(
-            f"Exception in update_external_document_permissions_task: "
+            f"update_external_document_permissions_task exceptioned: "
             f"connector_id={connector_id} doc_id={doc_id}"
         )
+        completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
+    finally:
+        task_logger.info(
+            f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
+        )
+
+    if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
+        return False

     task_logger.info(
         f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
     )
     return True
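The `completion_status` changes above follow a single-exit logging pattern worth naming: record an enum value on every path, log it exactly once in `finally`, and derive the return value from it. Distilled into a standalone sketch (the enum values mirror, but are not, `OnyxCeleryTaskCompletionStatus`):

```python
from enum import Enum
from typing import Callable


class CompletionStatus(str, Enum):
    UNDEFINED = "undefined"
    SUCCEEDED = "succeeded"
    NON_RETRYABLE_EXCEPTION = "non_retryable_exception"


def run_and_report(work: Callable[[], None]) -> bool:
    status = CompletionStatus.UNDEFINED
    try:
        work()
        status = CompletionStatus.SUCCEEDED
    except Exception:
        status = CompletionStatus.NON_RETRYABLE_EXCEPTION
    finally:
        # exactly one terminal log line, whatever path we took
        print(f"task completed: status={status.value}")
    return status == CompletionStatus.SUCCEEDED
```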
 def validate_permission_sync_fences(
-    tenant_id: str | None,
+    tenant_id: str,
     r: Redis,
     r_replica: Redis,
     r_celery: Redis,

@@ -569,7 +631,7 @@ def validate_permission_sync_fences(


 def validate_permission_sync_fence(
-    tenant_id: str | None,
+    tenant_id: str,
     key_bytes: bytes,
     queued_tasks: set[str],
     reserved_tasks: set[str],

@@ -779,7 +841,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):


 def monitor_ccpair_permissions_taskset(
-    tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
+    tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
 ) -> None:
     fence_key = key_bytes.decode("utf-8")
     cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -37,9 +37,12 @@ from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisConstants
 from onyx.configs.constants import OnyxRedisLocks
 from onyx.configs.constants import OnyxRedisSignals
+from onyx.connectors.exceptions import ConnectorValidationError
+from onyx.connectors.factory import validate_ccpair_for_user
 from onyx.db.connector import mark_cc_pair_as_external_group_synced
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.connector_credential_pair import update_connector_credential_pair
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.enums import AccessType
 from onyx.db.enums import ConnectorCredentialPairStatus
 from onyx.db.enums import SyncStatus

@@ -55,6 +58,7 @@ from onyx.redis.redis_connector_ext_group_sync import (
 from onyx.redis.redis_pool import get_redis_client
 from onyx.redis.redis_pool import get_redis_replica_client
 from onyx.server.utils import make_short_id
+from onyx.utils.logger import format_error_for_logging
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -119,11 +123,11 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
     soft_time_limit=JOB_TIMEOUT,
     bind=True,
 )
-def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
+def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
     # we need to use celery's redis client to access its redis data
     # (which lives on a different db number)
-    r = get_redis_client(tenant_id=tenant_id)
-    r_replica = get_redis_replica_client(tenant_id=tenant_id)
+    r = get_redis_client()
+    r_replica = get_redis_replica_client()
     r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

     lock_beat: RedisLock = r.lock(

@@ -140,7 +144,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool

     try:
         cc_pair_ids_to_sync: list[int] = []
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             cc_pairs = get_all_auto_sync_cc_pairs(db_session)

             # We only want to sync one cc_pair per source type in

@@ -148,7 +152,10 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
             for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
                 # These are ordered by cc_pair id so the first one is the one we want
                 cc_pairs_to_dedupe = get_cc_pairs_by_source(
-                    db_session, source, only_sync=True
+                    db_session,
+                    source,
+                    access_type=AccessType.SYNC,
+                    status=ConnectorCredentialPairStatus.ACTIVE,
                 )
                 # We only want to sync one cc_pair per source type
                 # in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here

@@ -195,12 +202,17 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
         )
-    except Exception:
+    except Exception as e:
+        error_msg = format_error_for_logging(e)
+        task_logger.warning(
+            f"Unexpected check_for_external_group_sync exception: tenant={tenant_id} {error_msg}"
+        )
         task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
     finally:
         if lock_beat.owned():
             lock_beat.release()

+    task_logger.info(f"check_for_external_group_sync finished: tenant={tenant_id}")
     return True
@@ -208,7 +220,7 @@ def try_creating_external_group_sync_task(
     app: Celery,
     cc_pair_id: int,
     r: Redis,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> str | None:
     """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
     Returns None if no syncing is required."""

@@ -230,7 +242,7 @@ def try_creating_external_group_sync_task(
     # create before setting fence to avoid race condition where the monitoring
     # task updates the sync record before it is created
     try:
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             insert_sync_record(
                 db_session=db_session,
                 entity_id=cc_pair_id,

@@ -267,12 +279,19 @@ def try_creating_external_group_sync_task(
         redis_connector.external_group_sync.set_fence(payload)

         payload_id = payload.id
-    except Exception:
+    except Exception as e:
+        error_msg = format_error_for_logging(e)
+        task_logger.warning(
+            f"Unexpected try_creating_external_group_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
+        )
         task_logger.exception(
             f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
         )
         return None

+    task_logger.info(
+        f"try_creating_external_group_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
+    )
     return payload_id
@@ -287,7 +306,7 @@ def try_creating_external_group_sync_task(
 def connector_external_group_sync_generator_task(
     self: Task,
     cc_pair_id: int,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> None:
     """
     External group sync task for a given connector credential pair

@@ -296,7 +315,7 @@ def connector_external_group_sync_generator_task(

     redis_connector = RedisConnector(tenant_id, cc_pair_id)

-    r = get_redis_client(tenant_id=tenant_id)
+    r = get_redis_client()

     # this wait is needed to avoid a race condition where
     # the primary worker sends the task and it is immediately executed

@@ -357,16 +376,40 @@ def connector_external_group_sync_generator_task(
         payload.started = datetime.now(timezone.utc)
         redis_connector.external_group_sync.set_fence(payload)

-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             cc_pair = get_connector_credential_pair_from_id(
                 db_session=db_session,
                 cc_pair_id=cc_pair_id,
+                eager_load_credential=True,
             )
             if cc_pair is None:
                 raise ValueError(
                     f"No connector credential pair found for id: {cc_pair_id}"
                 )

+            try:
+                created = validate_ccpair_for_user(
+                    cc_pair.connector.id,
+                    cc_pair.credential.id,
+                    db_session,
+                    enforce_creation=False,
+                )
+                if not created:
+                    task_logger.warning(
+                        f"Unable to create connector credential pair for id: {cc_pair_id}"
+                    )
+            except Exception:
+                task_logger.exception(
+                    f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
+                )
+                update_connector_credential_pair(
+                    db_session=db_session,
+                    connector_id=cc_pair.connector.id,
+                    credential_id=cc_pair.credential.id,
+                    status=ConnectorCredentialPairStatus.INVALID,
+                )
+                raise
+
             source_type = cc_pair.connector.source

             ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)

@@ -378,12 +421,23 @@ def connector_external_group_sync_generator_task(
             logger.info(
                 f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
             )

-            external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
+            external_user_groups: list[ExternalUserGroup] = []
+            try:
+                external_user_groups = ext_group_sync_func(cc_pair)
+            except ConnectorValidationError as e:
+                msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
+                update_connector_credential_pair(
+                    db_session=db_session,
+                    connector_id=cc_pair.connector.id,
+                    credential_id=cc_pair.credential.id,
+                    status=ConnectorCredentialPairStatus.INVALID,
+                )
+                raise e

             logger.info(
                 f"Syncing {len(external_user_groups)} external user groups for {source_type}"
             )
             logger.debug(f"New external user groups: {external_user_groups}")

             replace_user__ext_group_for_cc_pair(
                 db_session=db_session,

@@ -404,11 +458,19 @@ def connector_external_group_sync_generator_task(
             sync_status=SyncStatus.SUCCESS,
         )
     except Exception as e:
-        task_logger.exception(
-            f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
-        )
+        error_msg = format_error_for_logging(e)
+        task_logger.warning(
+            f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id} {error_msg}"
+        )
+
+        msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
+        task_logger.exception(msg)
+        emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)

-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             update_sync_record_status(
                 db_session=db_session,
                 entity_id=cc_pair_id,

@@ -431,7 +493,7 @@ def connector_external_group_sync_generator_task(


 def validate_external_group_sync_fences(
-    tenant_id: str | None,
+    tenant_id: str,
     celery_app: Celery,
     r: Redis,
     r_replica: Redis,

@@ -459,12 +521,11 @@ def validate_external_group_sync_fences(
         )

         lock_beat.reacquire()
-
     return


 def validate_external_group_sync_fence(
-    tenant_id: str | None,
+    tenant_id: str,
     key_bytes: bytes,
     reserved_tasks: set[str],
     r_celery: Redis,
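Both sync generators now share the same failure posture: when validation fails (or a `ConnectorValidationError` fires during the sync itself), flip the cc_pair to INVALID so the scheduler stops retrying it, then re-raise. The shape, reduced to its essentials — the callables below stand in for Onyx's DB helpers:

```python
from typing import Callable


class ConnectorValidationError(Exception):
    """Raised when a connector's credentials or config no longer validate."""


def sync_or_invalidate(
    sync_fn: Callable[[], None],
    mark_invalid: Callable[[], None],
) -> None:
    try:
        sync_fn()
    except ConnectorValidationError:
        mark_invalid()  # stop the beat loop from rescheduling a broken pair
        raise  # preserve the normal exception/logging path
```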
@@ -41,16 +41,18 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
 from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
+from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
 from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisConstants
 from onyx.configs.constants import OnyxRedisLocks
 from onyx.configs.constants import OnyxRedisSignals
+from onyx.connectors.exceptions import ConnectorValidationError
 from onyx.db.connector import mark_ccpair_with_indexing_trigger
 from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.enums import IndexingMode
 from onyx.db.enums import IndexingStatus
 from onyx.db.index_attempt import get_index_attempt
@@ -90,6 +92,9 @@ class IndexingWatchdogTerminalStatus(str, Enum):
     SUCCEEDED = "succeeded"

     SPAWN_FAILED = "spawn_failed"  # connector spawn failed
+    SPAWN_NOT_ALIVE = (
+        "spawn_not_alive"  # spawn succeeded but process did not come alive
+    )

     BLOCKED_BY_DELETION = "blocked_by_deletion"
     BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"

@@ -103,6 +108,9 @@ class IndexingWatchdogTerminalStatus(str, Enum):
         "index_attempt_mismatch"  # expected index attempt metadata not found in db
     )

+    CONNECTOR_VALIDATION_ERROR = (
+        "connector_validation_error"  # the connector validation failed
+    )
     CONNECTOR_EXCEPTIONED = "connector_exceptioned"  # the connector itself exceptioned
     WATCHDOG_EXCEPTIONED = "watchdog_exceptioned"  # the watchdog exceptioned

@@ -112,6 +120,8 @@ class IndexingWatchdogTerminalStatus(str, Enum):
     # the watchdog terminated the task due to no activity
     TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"

+    # NOTE: this may actually be the same as SIGKILL, but parsed differently by python
+    # consolidate once we know more
     OUT_OF_MEMORY = "out_of_memory"

     PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"

@@ -121,6 +131,7 @@ class IndexingWatchdogTerminalStatus(str, Enum):
         _ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
             IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
             IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
+            IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
             IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
             IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
             IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,

@@ -137,6 +148,8 @@ class IndexingWatchdogTerminalStatus(str, Enum):
     def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
         _CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
             -9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
+            137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
+            247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
             248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
             249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
             250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
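A note on the code tables above: 137 is the conventional shell-style 128 + SIGKILL(9) exit status, while a negative exitcode is how Python's multiprocessing reports death-by-signal — which is why the enum carries both spellings for what may be the same cause. A reduced sketch of the round-trip (statuses trimmed to three for brevity):

```python
from enum import Enum


class TerminalStatus(str, Enum):
    OUT_OF_MEMORY = "out_of_memory"
    PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
    WATCHDOG_EXCEPTIONED = "watchdog_exceptioned"

    @classmethod
    def from_code(cls, code: int) -> "TerminalStatus":
        # 137 (128 + 9) and -9 are the same underlying SIGKILL, surfaced
        # through different runtimes (shell vs multiprocessing exitcode).
        mapping = {137: cls.OUT_OF_MEMORY, -9: cls.PROCESS_SIGNAL_SIGKILL}
        return mapping.get(code, cls.WATCHDOG_EXCEPTIONED)


assert TerminalStatus.from_code(137) is TerminalStatus.OUT_OF_MEMORY
```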
@@ -169,7 +182,7 @@ class SimpleJobResult:


 class ConnectorIndexingContext(BaseModel):
-    tenant_id: str | None
+    tenant_id: str
     cc_pair_id: int
     search_settings_id: int
     index_attempt_id: int

@@ -197,7 +210,7 @@ class ConnectorIndexingLogBuilder:


 def monitor_ccpair_indexing_taskset(
-    tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
+    tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
 ) -> None:
     # if the fence doesn't exist, there's nothing to do
     fence_key = key_bytes.decode("utf-8")
@@ -345,15 +358,16 @@ def monitor_ccpair_indexing_taskset(
     soft_time_limit=300,
     bind=True,
 )
-def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
+def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
     """a lightweight task used to kick off indexing tasks.
     Occasionally does some validation of existing state to clear up error conditions"""

     time_start = time.monotonic()

     tasks_created = 0
     locked = False
-    redis_client = get_redis_client(tenant_id=tenant_id)
-    redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
+    redis_client = get_redis_client()
+    redis_client_replica = get_redis_replica_client()

     # we need to use celery's redis client to access its redis data
     # (which lives on a different db number)

@@ -391,7 +405,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
     # 1/3: KICKOFF

     # check for search settings swap
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+    with get_session_with_current_tenant() as db_session:
         old_search_settings = check_index_swap(db_session=db_session)
         current_search_settings = get_current_search_settings(db_session)
         # So that the first time users aren't surprised by really slow speed of first

@@ -412,7 +426,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
     # gather cc_pair_ids
     lock_beat.reacquire()
     cc_pair_ids: list[int] = []
-    with get_session_with_tenant(tenant_id) as db_session:
+    with get_session_with_current_tenant() as db_session:
         cc_pairs = fetch_connector_credential_pairs(db_session)
         for cc_pair_entry in cc_pairs:
             cc_pair_ids.append(cc_pair_entry.id)

@@ -422,7 +436,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
         lock_beat.reacquire()

         redis_connector = RedisConnector(tenant_id, cc_pair_id)
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             search_settings_list = get_active_search_settings_list(db_session)
             for search_settings_instance in search_settings_list:
                 redis_connector_index = redis_connector.new_index(

@@ -500,7 +514,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:

     # Fail any index attempts in the DB that don't have fences
     # This shouldn't ever happen!
-    with get_session_with_tenant(tenant_id) as db_session:
+    with get_session_with_current_tenant() as db_session:
         unfenced_attempt_ids = get_unfenced_index_attempt_ids(
             db_session, redis_client
         )

@@ -552,7 +566,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:

         key_str = key_bytes.decode("utf-8")
         if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
-            with get_session_with_tenant(tenant_id) as db_session:
+            with get_session_with_current_tenant() as db_session:
                 monitor_ccpair_indexing_taskset(
                     tenant_id, key_bytes, redis_client_replica, db_session
                 )
@@ -583,8 +597,8 @@ def connector_indexing_task(
     index_attempt_id: int,
     cc_pair_id: int,
     search_settings_id: int,
-    tenant_id: str | None,
-    is_ee: bool,
+    tenant_id: str,
 ) -> int | None:
     """Indexing task. For a cc pair, this task pulls all document IDs from the source
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing

@@ -635,7 +649,7 @@ def connector_indexing_task(
     redis_connector = RedisConnector(tenant_id, cc_pair_id)
     redis_connector_index = redis_connector.new_index(search_settings_id)

-    r = get_redis_client(tenant_id=tenant_id)
+    r = get_redis_client()

     if redis_connector.delete.fenced:
         raise SimpleJobException(

@@ -729,7 +743,7 @@ def connector_indexing_task(
     redis_connector_index.set_fence(payload)

     try:
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             attempt = get_index_attempt(db_session, index_attempt_id)
             if not attempt:
                 raise SimpleJobException(

@@ -764,9 +778,9 @@ def connector_indexing_task(
         callback = IndexingCallback(
             os.getppid(),
             redis_connector,
-            redis_connector_index,
             lock,
             r,
+            redis_connector_index,
         )

         logger.info(

@@ -788,6 +802,15 @@ def connector_indexing_task(
         # get back the total number of indexed docs and return it
         n_final_progress = redis_connector_index.get_progress()
         redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
+    except ConnectorValidationError:
+        raise SimpleJobException(
+            f"Indexing task failed: attempt={index_attempt_id} "
+            f"tenant={tenant_id} "
+            f"cc_pair={cc_pair_id} "
+            f"search_settings={search_settings_id}",
+            code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
+        )
+
     except Exception as e:
         logger.exception(
             f"Indexing spawned task failed: attempt={index_attempt_id} "

@@ -795,8 +818,8 @@ def connector_indexing_task(
             f"cc_pair={cc_pair_id} "
             f"search_settings={search_settings_id}"
         )

         raise e

     finally:
         if lock.owned():
             lock.release()
@@ -867,7 +890,7 @@ def connector_indexing_proxy_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
|
||||
and forking is inherently unstable.
|
||||
@@ -876,6 +899,9 @@ def connector_indexing_proxy_task(
|
||||
|
||||
TODO(rkuo): refactor this so that there is a single return path where we canonically
|
||||
log the result of running this function.
|
||||
|
||||
NOTE: we try/except all db access in this function because as a watchdog, this function
|
||||
needs to be extremely stable.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
@@ -901,18 +927,18 @@ def connector_indexing_proxy_task(
|
||||
task_logger.error("self.request.id is None!")
|
||||
|
||||
client = SimpleJobClient()
|
||||
task_logger.info(f"submitting connector_indexing_task with tenant_id={tenant_id}")
|
||||
|
||||
job = client.submit(
|
||||
connector_indexing_task,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
tenant_id,
|
||||
global_version.is_ee_version(),
|
||||
pure=False,
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
if not job:
|
||||
if not job or not job.process:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
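Note that `tenant_id` is now the last positional argument to `client.submit(...)`, replacing the old `is_ee`/`pure=False` trailing arguments. That position matters: the `_initializer` hunk near the end of this diff recovers the tenant in the spawned child by scanning the positional args in reverse for a `TENANT_ID_PREFIX`-prefixed string, roughly:

    # sketch of the recovery side (mirrors the _initializer hunk further below)
    tenant_id = POSTGRES_DEFAULT_SCHEMA  # single-tenant fallback
    for arg in reversed(args):
        if isinstance(arg, str) and arg.startswith(TENANT_ID_PREFIX):
            tenant_id = arg
            break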
@@ -923,13 +949,39 @@
         )
         return

-    task_logger.info(log_builder.build("Indexing watchdog - spawn succeeded"))
+    # Ensure the process has moved out of the starting state
+    num_waits = 0
+    while True:
+        if num_waits > 15:
+            result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
+            task_logger.info(
+                log_builder.build(
+                    "Indexing watchdog - finished",
+                    status=str(result.status.value),
+                    exit_code=str(result.exit_code),
+                )
+            )
+            job.release()
+            return
+
+        if job.process.is_alive() or job.process.exitcode is not None:
+            break
+
+        sleep(1)
+        num_waits += 1
+
+    task_logger.info(
+        log_builder.build(
+            "Indexing watchdog - spawn succeeded",
+            pid=str(job.process.pid),
+        )
+    )

     redis_connector = RedisConnector(tenant_id, cc_pair_id)
     redis_connector_index = redis_connector.new_index(search_settings_id)

     try:
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             index_attempt = get_index_attempt(
                 db_session=db_session, index_attempt_id=index_attempt_id
             )
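The new wait loop above avoids acting on a `multiprocessing` process that has not yet left its starting state. Condensed to its essence (a sketch, assuming `job.process` is a `multiprocessing.Process`):

    import multiprocessing as mp
    import time

    def wait_until_started(process: mp.Process, attempts: int = 15) -> bool:
        # is_alive() is False both before start and after exit; a non-None
        # exitcode distinguishes "already finished" from "not yet started"
        for _ in range(attempts):
            if process.is_alive() or process.exitcode is not None:
                return True
            time.sleep(1)
        return False  # treated above as SPAWN_NOT_ALIVE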
@@ -940,6 +992,9 @@
                 index_attempt.connector_credential_pair.connector.source.value
             )

+    redis_connector_index.set_active()  # renew active signal
+    redis_connector_index.set_connector_active()  # prime the connector active signal
+
     while True:
         sleep(5)

@@ -965,7 +1020,7 @@
             job.release()
             break

-        # if a termination signal is detected, clean up and break
+        # if a termination signal is detected, break (exit point will clean up)
         if self.request.id and redis_connector_index.terminating(self.request.id):
             task_logger.warning(
                 log_builder.build("Indexing watchdog - termination signal detected")
@@ -974,10 +1029,24 @@
             result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
             break

+        # if activity timeout is detected, break (exit point will clean up)
+        if not redis_connector_index.connector_active():
+            task_logger.warning(
+                log_builder.build(
+                    "Indexing watchdog - activity timeout exceeded",
+                    timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
+                )
+            )
+
+            result.status = (
+                IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
+            )
+            break
+
-        # if the spawned task is still running, restart the check once again
+        # if the index attempt is not in a finished status
         try:
-            with get_session_with_tenant(tenant_id) as db_session:
+            with get_session_with_current_tenant() as db_session:
                 index_attempt = get_index_attempt(
                     db_session=db_session, index_attempt_id=index_attempt_id
                 )
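The activity-timeout check pairs with `set_connector_active()` in the callback changes further down: progress refreshes a Redis key with a TTL, and the watchdog treats an expired key as a stalled connector. The heartbeat idea in isolation (key name and TTL constant are illustrative, not the real ones):

    from redis import Redis

    HEARTBEAT_TTL = 3 * 60 * 60  # plays the role of CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT

    def set_connector_active(r: Redis, key: str) -> None:
        r.set(key, 1, ex=HEARTBEAT_TTL)  # refreshed on every progress callback

    def connector_active(r: Redis, key: str) -> bool:
        return r.exists(key) > 0  # False once progress has stalled past the TTL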
@@ -987,25 +1056,29 @@

                 if not index_attempt.is_finished():
                     continue

             except Exception:
                 # if the DB exceptioned, just restart the check.
                 # polling the index attempt status doesn't need to be strongly consistent
                 task_logger.exception(
                     log_builder.build(
                         "Indexing watchdog - transient exception looking up index attempt"
                     )
                 )
                 continue

-    except Exception:
+    except Exception as e:
         result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
-        result.exception_str = traceback.format_exc()
+        if isinstance(e, ConnectorValidationError):
+            # No need to expose full stack trace for validation errors
+            result.exception_str = str(e)
+        else:
+            result.exception_str = traceback.format_exc()

     # handle exit and reporting
     elapsed = time.monotonic() - start
     if result.exception_str is not None:
         # print with exception
         try:
-            with get_session_with_tenant(tenant_id) as db_session:
+            with get_session_with_current_tenant() as db_session:
                 failure_reason = (
                     f"Spawned task exceptioned: exit_code={result.exit_code}"
                 )
@@ -1045,15 +1118,13 @@
     # print without exception
     if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
         try:
-            with get_session_with_tenant(tenant_id) as db_session:
+            with get_session_with_current_tenant() as db_session:
                 mark_attempt_canceled(
                     index_attempt_id,
                     db_session,
                     "Connector termination signal detected",
                 )
         except Exception:
             # if the DB exceptions, we'll just get an unfriendly failure message
             # in the UI instead of the cancellation message
             task_logger.exception(
                 log_builder.build(
                     "Indexing watchdog - transient exception marking index attempt as canceled"
@@ -1061,6 +1132,25 @@
                 )
             )

         job.cancel()
+    elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
+        try:
+            with get_session_with_current_tenant() as db_session:
+                mark_attempt_failed(
+                    index_attempt_id,
+                    db_session,
+                    "Indexing watchdog - activity timeout exceeded: "
+                    f"attempt={index_attempt_id} "
+                    f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
+                )
+        except Exception:
+            logger.exception(
+                log_builder.build(
+                    "Indexing watchdog - transient exception marking index attempt as failed"
+                )
+            )
+        job.cancel()
     else:
         pass

     task_logger.info(
         log_builder.build(
@@ -1080,7 +1170,7 @@
     name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
     soft_time_limit=300,
 )
-def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
+def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
     """Clean up old checkpoints that are older than 7 days."""
     locked = False
     redis_client = get_redis_client(tenant_id=tenant_id)
@@ -1095,7 +1185,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:

     try:
         locked = True
-        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             old_attempts = get_index_attempts_with_old_checkpoints(db_session)
             for attempt in old_attempts:
                 task_logger.info(
@@ -1131,5 +1221,5 @@ def cleanup_checkpoint_task(
     self: Task, *, index_attempt_id: int, tenant_id: str | None
 ) -> None:
     """Clean up a checkpoint for a given index attempt"""
-    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
+    with get_session_with_current_tenant() as db_session:
         cleanup_checkpoint(db_session, index_attempt_id)

@@ -23,7 +23,7 @@ from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisConstants
 from onyx.db.engine import get_db_current_time
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.enums import ConnectorCredentialPairStatus
 from onyx.db.enums import IndexingStatus
 from onyx.db.enums import IndexModelStatus
@@ -93,27 +93,25 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
     return unfenced_attempts


-class IndexingCallback(IndexingHeartbeatInterface):
+class IndexingCallbackBase(IndexingHeartbeatInterface):
     PARENT_CHECK_INTERVAL = 60

     def __init__(
         self,
         parent_pid: int,
         redis_connector: RedisConnector,
-        redis_connector_index: RedisConnectorIndex,
         redis_lock: RedisLock,
         redis_client: Redis,
     ):
         super().__init__()
         self.parent_pid = parent_pid
         self.redis_connector: RedisConnector = redis_connector
-        self.redis_connector_index: RedisConnectorIndex = redis_connector_index
         self.redis_lock: RedisLock = redis_lock
         self.redis_client = redis_client
         self.started: datetime = datetime.now(timezone.utc)
         self.redis_lock.reacquire()

-        self.last_tag: str = "IndexingCallback.__init__"
+        self.last_tag: str = f"{self.__class__.__name__}.__init__"
         self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
         self.last_lock_monotonic = time.monotonic()

@@ -127,8 +125,8 @@ class IndexingCallback(IndexingHeartbeatInterface):

     def progress(self, tag: str, amount: int) -> None:
         # rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
-        # with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
-        # so leave this code in until we're ready to test it.
+        # with daemon=True. It seems likely some indexing tasks will need to spawn other processes
+        # eventually, which daemon=True prevents, so leave this code in until we're ready to test it.

         # if self.parent_pid:
         #     # check if the parent pid is alive so we aren't running as a zombie
@@ -143,8 +141,6 @@ class IndexingCallback(IndexingHeartbeatInterface):
         #         self.last_parent_check = now

         try:
-            self.redis_connector.prune.set_active()
-
             current_time = time.monotonic()
             if current_time - self.last_lock_monotonic >= (
                 CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
@@ -156,7 +152,7 @@ class IndexingCallback(IndexingHeartbeatInterface):
             self.last_tag = tag
         except LockError:
             logger.exception(
-                f"IndexingCallback - lock.reacquire exceptioned: "
+                f"{self.__class__.__name__} - lock.reacquire exceptioned: "
                 f"lock_timeout={self.redis_lock.timeout} "
                 f"start={self.started} "
                 f"last_tag={self.last_tag} "
@@ -167,13 +163,31 @@ class IndexingCallback(IndexingHeartbeatInterface):
             redis_lock_dump(self.redis_lock, self.redis_client)
             raise


+class IndexingCallback(IndexingCallbackBase):
+    def __init__(
+        self,
+        parent_pid: int,
+        redis_connector: RedisConnector,
+        redis_lock: RedisLock,
+        redis_client: Redis,
+        redis_connector_index: RedisConnectorIndex,
+    ):
+        super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
+
+        self.redis_connector_index: RedisConnectorIndex = redis_connector_index
+
+    def progress(self, tag: str, amount: int) -> None:
+        self.redis_connector_index.set_active()
+        self.redis_connector_index.set_connector_active()
+        super().progress(tag, amount)
+        self.redis_client.incrby(
+            self.redis_connector_index.generator_progress_key, amount
+        )
+
+
 def validate_indexing_fence(
-    tenant_id: str | None,
+    tenant_id: str,
     key_bytes: bytes,
     reserved_tasks: set[str],
     r_celery: Redis,
@@ -297,7 +311,7 @@ def validate_indexing_fence(


 def validate_indexing_fences(
-    tenant_id: str | None,
+    tenant_id: str,
     r_replica: Redis,
     r_celery: Redis,
     lock_beat: RedisLock,
@@ -318,7 +332,7 @@ def validate_indexing_fences(
         if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
             continue

-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             validate_indexing_fence(
                 tenant_id,
                 key_bytes,
@@ -428,7 +442,7 @@ def try_creating_indexing_task(
     reindex: bool,
     db_session: Session,
     r: Redis,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> int | None:
     """Checks for any conditions that should block the indexing task from being
     created, then creates the task.

@@ -8,7 +8,7 @@ from onyx.background.celery.apps.app_base import task_logger
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
 from onyx.configs.constants import OnyxCeleryTask
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.models import LLMProvider


@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
     trail=False,
     bind=True,
 )
-def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
+def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
     if not LLM_MODEL_UPDATE_API_URL:
         raise ValueError("LLM model update API URL not configured")

@@ -75,7 +75,7 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | N
         return None

     # Then update the database with the fetched models
-    with get_session_with_tenant(tenant_id) as db_session:
+    with get_session_with_current_tenant() as db_session:
         # Get the default LLM provider
         default_provider = (
             db_session.query(LLMProvider)

@@ -26,7 +26,8 @@ from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisLocks
 from onyx.db.engine import get_all_tenant_ids
 from onyx.db.engine import get_db_current_time
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
+from onyx.db.engine import get_session_with_shared_schema
 from onyx.db.enums import IndexingStatus
 from onyx.db.enums import SyncStatus
 from onyx.db.enums import SyncType
@@ -42,7 +43,6 @@ from onyx.utils.telemetry import optional_telemetry
 from onyx.utils.telemetry import RecordType
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
-

 _MONITORING_SOFT_TIME_LIMIT = 60 * 5  # 5 minutes
 _MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60  # 6 minutes

@@ -91,7 +91,7 @@ class Metric(BaseModel):
         }
         task_logger.info(json.dumps(data))

-    def emit(self, tenant_id: str | None) -> None:
+    def emit(self, tenant_id: str) -> None:
         # Convert value to appropriate type based on the input value
         bool_value = None
         float_value = None
@@ -656,7 +656,7 @@ def build_job_id(
     queue=OnyxCeleryQueues.MONITORING,
     bind=True,
 )
-def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
+def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
     """Collect and emit metrics about background processes.
     This task runs periodically to gather metrics about:
     - Queue lengths for different Celery queues
@@ -668,7 +668,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
     CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

     task_logger.info("Starting background monitoring")
-    r = get_redis_client(tenant_id=tenant_id)
+    r = get_redis_client()

     lock_monitoring: RedisLock = r.lock(
         OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
@@ -683,7 +683,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
     try:
         # Get Redis client for Celery broker
         redis_celery = self.app.broker_connection().channel().client  # type: ignore
-        redis_std = get_redis_client(tenant_id=tenant_id)
+        redis_std = get_redis_client()

         # Define metric collection functions and their dependencies
         metric_functions: list[Callable[[], list[Metric]]] = [
@@ -693,7 +693,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
         ]

         # Collect and log each metric
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             for metric_fn in metric_functions:
                 metrics = metric_fn()
                 for metric in metrics:
@@ -771,12 +771,11 @@ def cloud_check_alembic() -> bool | None:
         if tenant_id is None:
             continue

-        with get_session_with_tenant(tenant_id=None) as session:
+        with get_session_with_shared_schema() as session:
             try:
                 result = session.execute(
                     text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
                 )

                 result_scalar: str | None = result.scalar_one_or_none()
                 if result_scalar is None:
                     raise ValueError("Alembic version should not be None.")
@@ -865,7 +864,7 @@ def cloud_monitor_celery_queues(


 @shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
-def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
+def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
     return monitor_celery_queues_helper(self)


@@ -15,7 +15,7 @@ from onyx.background.celery.apps.app_base import task_logger
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import PostgresAdvisoryLocks
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant


 @shared_task(
@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_tenant
     bind=True,
     base=AbortableTask,
 )
-def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
+def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
     """Runs periodically to clean up the kombu_message table"""

     # we will select messages older than this amount to clean up
@@ -36,7 +36,7 @@ def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
     ctx["deleted"] = 0
     ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
     ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
-    with get_session_with_tenant(tenant_id) as db_session:
+    with get_session_with_current_tenant() as db_session:
         # Exit the task if we can't take the advisory lock
         result = db_session.execute(
             text("SELECT pg_try_advisory_lock(:id)"),

@@ -21,7 +21,7 @@ from onyx.background.celery.celery_redis import celery_get_queue_length
 from onyx.background.celery.celery_redis import celery_get_queued_task_ids
 from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
 from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
-from onyx.background.celery.tasks.indexing.utils import IndexingCallback
+from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
 from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -41,7 +41,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
 from onyx.db.connector_credential_pair import get_connector_credential_pairs
 from onyx.db.document import get_documents_for_connector_credential_pair
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.enums import ConnectorCredentialPairStatus
 from onyx.db.enums import SyncStatus
 from onyx.db.enums import SyncType
@@ -55,6 +55,7 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
 from onyx.redis.redis_pool import get_redis_client
 from onyx.redis.redis_pool import get_redis_replica_client
 from onyx.server.utils import make_short_id
+from onyx.utils.logger import format_error_for_logging
 from onyx.utils.logger import LoggerContextVars
 from onyx.utils.logger import pruning_ctx
 from onyx.utils.logger import setup_logger
@@ -62,6 +63,12 @@ from onyx.utils.logger import setup_logger
 logger = setup_logger()


+class PruneCallback(IndexingCallbackBase):
+    def progress(self, tag: str, amount: int) -> None:
+        self.redis_connector.prune.set_active()
+        super().progress(tag, amount)
+
+
 """Jobs / utils for kicking off pruning tasks."""

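`PruneCallback` here and the reworked `IndexingCallback` earlier in this diff are now thin subclasses of `IndexingCallbackBase`: the base owns lock reacquisition and the parent-process checks, and each subclass's `progress()` adds only its own Redis signal before delegating. Note also that the new `IndexingCallback.__init__` takes `redis_connector_index` as the final parameter, which is why the construction site earlier in the diff reorders its arguments:

    callback = IndexingCallback(
        os.getppid(),           # parent pid to watch
        redis_connector,        # connector-level Redis wrapper
        lock,                   # RedisLock reacquired on progress
        r,                      # Redis client
        redis_connector_index,  # index-attempt wrapper, now the final parameter
    )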
@@ -107,9 +114,9 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
     soft_time_limit=JOB_TIMEOUT,
     bind=True,
 )
-def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
-    r = get_redis_client(tenant_id=tenant_id)
-    r_replica = get_redis_replica_client(tenant_id=tenant_id)
+def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
+    r = get_redis_client()
+    r_replica = get_redis_replica_client()
     r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

     lock_beat: RedisLock = r.lock(
@@ -127,14 +134,14 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
         # but pruning only kicks off once per hour
         if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
             cc_pair_ids: list[int] = []
-            with get_session_with_tenant(tenant_id) as db_session:
+            with get_session_with_current_tenant() as db_session:
                 cc_pairs = get_connector_credential_pairs(db_session)
                 for cc_pair_entry in cc_pairs:
                     cc_pair_ids.append(cc_pair_entry.id)

             for cc_pair_id in cc_pair_ids:
                 lock_beat.reacquire()
-                with get_session_with_tenant(tenant_id) as db_session:
+                with get_session_with_current_tenant() as db_session:
                     cc_pair = get_connector_credential_pair_from_id(
                         db_session=db_session,
                         cc_pair_id=cc_pair_id,
@@ -182,18 +189,20 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:

             key_str = key_bytes.decode("utf-8")
             if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
-                with get_session_with_tenant(tenant_id) as db_session:
+                with get_session_with_current_tenant() as db_session:
                     monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
         )
-    except Exception:
+    except Exception as e:
+        error_msg = format_error_for_logging(e)
+        task_logger.warning(f"Unexpected pruning check exception: {error_msg}")
         task_logger.exception("Unexpected exception during pruning check")
     finally:
         if lock_beat.owned():
             lock_beat.release()

     task_logger.info(f"check_for_pruning finished: tenant={tenant_id}")
     return True


@@ -202,7 +211,7 @@ def try_creating_prune_generator_task(
     cc_pair: ConnectorCredentialPair,
     db_session: Session,
     r: Redis,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> str | None:
     """Checks for any conditions that should block the pruning generator task from being
     created, then creates the task.
@@ -295,13 +304,19 @@ def try_creating_prune_generator_task(
         redis_connector.prune.set_fence(payload)

         payload_id = payload.id
-    except Exception:
+    except Exception as e:
+        error_msg = format_error_for_logging(e)
+        task_logger.warning(
+            f"Unexpected try_creating_prune_generator_task exception: cc_pair={cc_pair.id} {error_msg}"
+        )
         task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
         return None
     finally:
         if lock.owned():
             lock.release()

+    task_logger.info(
+        f"try_creating_prune_generator_task finished: cc_pair={cc_pair.id} payload_id={payload_id}"
+    )
     return payload_id


@@ -318,7 +333,7 @@ def connector_pruning_generator_task(
     cc_pair_id: int,
     connector_id: int,
     credential_id: int,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> None:
     """connector pruning task. For a cc pair, this task pulls all document IDs from the source
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -337,7 +352,7 @@ def connector_pruning_generator_task(

     redis_connector = RedisConnector(tenant_id, cc_pair_id)

-    r = get_redis_client(tenant_id=tenant_id)
+    r = get_redis_client()

     # this wait is needed to avoid a race condition where
     # the primary worker sends the task and it is immediately executed
@@ -395,7 +410,7 @@ def connector_pruning_generator_task(
         return None

     try:
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             cc_pair = get_connector_credential_pair(
                 db_session=db_session,
                 connector_id=connector_id,
@@ -425,6 +440,7 @@ def connector_pruning_generator_task(
                 f"cc_pair={cc_pair_id} "
                 f"connector_source={cc_pair.connector.source}"
             )
+
             runnable_connector = instantiate_connector(
                 db_session,
                 cc_pair.connector.source,
@@ -434,12 +450,11 @@ def connector_pruning_generator_task(
             )

             search_settings = get_current_search_settings(db_session)
-            redis_connector_index = redis_connector.new_index(search_settings.id)
+            redis_connector.new_index(search_settings.id)

-            callback = IndexingCallback(
+            callback = PruneCallback(
                 0,
                 redis_connector,
-                redis_connector_index,
                 lock,
                 r,
             )
@@ -506,7 +521,7 @@ def connector_pruning_generator_task(


 def monitor_ccpair_pruning_taskset(
-    tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
+    tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
 ) -> None:
     fence_key = key_bytes.decode("utf-8")
     cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -552,7 +567,7 @@ def monitor_ccpair_pruning_taskset(


 def validate_pruning_fences(
-    tenant_id: str | None,
+    tenant_id: str,
     r: Redis,
     r_replica: Redis,
     r_celery: Redis,
@@ -600,7 +615,7 @@ def validate_pruning_fences(


 def validate_pruning_fence(
-    tenant_id: str | None,
+    tenant_id: str,
     key_bytes: bytes,
     reserved_tasks: set[str],
     queued_tasks: set[str],

@@ -32,7 +32,7 @@ class RetryDocumentIndex:
         self,
         doc_id: str,
         *,
-        tenant_id: str | None,
+        tenant_id: str,
         chunk_count: int | None,
     ) -> int:
         return self.index.delete_single(
@@ -50,7 +50,7 @@ class RetryDocumentIndex:
         self,
         doc_id: str,
         *,
-        tenant_id: str | None,
+        tenant_id: str,
         chunk_count: int | None,
         fields: VespaDocumentFields,
     ) -> int:

@@ -1,4 +1,5 @@
 import time
+from enum import Enum
 from http import HTTPStatus

 import httpx
@@ -27,7 +28,7 @@ from onyx.db.document import mark_document_as_modified
 from onyx.db.document import mark_document_as_synced
 from onyx.db.document_set import fetch_document_sets_for_document
 from onyx.db.engine import get_all_tenant_ids
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.search_settings import get_active_search_settings
 from onyx.document_index.factory import get_default_document_index
 from onyx.document_index.interfaces import VespaDocumentFields
@@ -45,6 +46,24 @@ LIGHT_SOFT_TIME_LIMIT = 105
 LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15


+class OnyxCeleryTaskCompletionStatus(str, Enum):
+    """The different statuses the watchdog can finish with.
+
+    TODO: create broader success/failure/abort categories
+    """
+
+    UNDEFINED = "undefined"
+
+    SUCCEEDED = "succeeded"
+
+    SKIPPED = "skipped"
+
+    SOFT_TIME_LIMIT = "soft_time_limit"
+
+    NON_RETRYABLE_EXCEPTION = "non_retryable_exception"
+    RETRYABLE_EXCEPTION = "retryable_exception"
+
+
 @shared_task(
     name=OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
     soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
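Deriving `OnyxCeleryTaskCompletionStatus` from both `str` and `Enum` keeps each status a proper enumeration while letting it behave as its string value in logs and comparisons:

    status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
    assert status == "succeeded"        # str mixin: compares equal to its value
    assert status.value == "succeeded"  # what the finally-block logging interpolates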
@@ -57,7 +76,7 @@ def document_by_cc_pair_cleanup_task(
     document_id: str,
     connector_id: int,
     credential_id: int,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> bool:
     """A lightweight subtask used to clean up document to cc pair relationships.
     Created by connection deletion and connector pruning parent tasks."""
@@ -78,8 +97,10 @@ def document_by_cc_pair_cleanup_task(

     start = time.monotonic()

+    completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
+
     try:
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             action = "skip"
             chunks_affected = 0

@@ -105,10 +126,14 @@ def document_by_cc_pair_cleanup_task(
                     tenant_id=tenant_id,
                     chunk_count=chunk_count,
                 )

                 delete_documents_complete__no_commit(
                     db_session=db_session,
                     document_ids=[document_id],
                 )
+                db_session.commit()
+
+                completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
             elif count > 1:
                 action = "update"

@@ -152,10 +177,11 @@ def document_by_cc_pair_cleanup_task(
                 )

                 mark_document_as_synced(document_id, db_session)
-            else:
-                pass

-            db_session.commit()
+                db_session.commit()
+                completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
+            else:
+                completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED

         elapsed = time.monotonic() - start
         task_logger.info(
@@ -167,57 +193,79 @@ def document_by_cc_pair_cleanup_task(
         )
     except SoftTimeLimitExceeded:
         task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
-        return False
+        completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
     except Exception as ex:
         e: Exception | None = None
-        if isinstance(ex, RetryError):
-            task_logger.warning(
-                f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
-            )
-
-            # only set the inner exception if it is of type Exception
-            e_temp = ex.last_attempt.exception()
-            if isinstance(e_temp, Exception):
-                e = e_temp
-        else:
-            e = ex
-
-        if isinstance(e, httpx.HTTPStatusError):
-            if e.response.status_code == HTTPStatus.BAD_REQUEST:
-                task_logger.exception(
-                    f"Non-retryable HTTPStatusError: "
-                    f"doc={document_id} "
-                    f"status={e.response.status_code}"
-                )
-                return False
-
-        task_logger.exception(f"Unexpected exception: doc={document_id}")
-
-        if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
-            # Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
-            countdown = 2 ** (self.request.retries + 4)
-            self.retry(exc=e, countdown=countdown)
-        else:
-            # This is the last attempt! mark the document as dirty in the db so that it
-            # eventually gets fixed out of band via stale document reconciliation
-            task_logger.warning(
-                f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
-                f"doc={document_id}"
-            )
-            with get_session_with_tenant(tenant_id) as db_session:
-                # delete the cc pair relationship now and let reconciliation clean it up
-                # in vespa
-                delete_document_by_connector_credential_pair__no_commit(
-                    db_session=db_session,
-                    document_id=document_id,
-                    connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
-                        connector_id=connector_id,
-                        credential_id=credential_id,
-                    ),
-                )
-                mark_document_as_modified(document_id, db_session)
+        while True:
+            if isinstance(ex, RetryError):
+                task_logger.warning(
+                    f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
+                )
+
+                # only set the inner exception if it is of type Exception
+                e_temp = ex.last_attempt.exception()
+                if isinstance(e_temp, Exception):
+                    e = e_temp
+            else:
+                e = ex
+
+            if isinstance(e, httpx.HTTPStatusError):
+                if e.response.status_code == HTTPStatus.BAD_REQUEST:
+                    task_logger.exception(
+                        f"Non-retryable HTTPStatusError: "
+                        f"doc={document_id} "
+                        f"status={e.response.status_code}"
+                    )
+                    completion_status = (
+                        OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
+                    )
+                    break
+
+            task_logger.exception(
+                f"document_by_cc_pair_cleanup_task exceptioned: doc={document_id}"
+            )
+
+            completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
+            if (
+                self.max_retries is not None
+                and self.request.retries >= self.max_retries
+            ):
+                # This is the last attempt! mark the document as dirty in the db so that it
+                # eventually gets fixed out of band via stale document reconciliation
+                task_logger.warning(
+                    f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
+                    f"doc={document_id}"
+                )
+                with get_session_with_current_tenant() as db_session:
+                    # delete the cc pair relationship now and let reconciliation clean it up
+                    # in vespa
+                    delete_document_by_connector_credential_pair__no_commit(
+                        db_session=db_session,
+                        document_id=document_id,
+                        connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
+                            connector_id=connector_id,
+                            credential_id=credential_id,
+                        ),
+                    )
+                    mark_document_as_modified(document_id, db_session)
+                completion_status = (
+                    OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
+                )
+                break
+
+            # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
+            countdown = 2 ** (self.request.retries + 4)
+            self.retry(exc=e, countdown=countdown)  # this will raise a celery exception
+            break  # we won't hit this, but it looks weird not to have it
+    finally:
+        task_logger.info(
+            f"document_by_cc_pair_cleanup_task completed: status={completion_status.value} doc={document_id}"
+        )
+
+        if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
+            return False

     task_logger.info(f"document_by_cc_pair_cleanup_task finished: doc={document_id}")
     return True

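For the retry arithmetic above, `2 ** (self.request.retries + 4)` yields a countdown of 16s on the first retry (`retries=0`), 32s on the second, and 64s on the third; once the retry budget is exhausted, the task falls through to the mark-dirty fallback instead of retrying again. As a quick check:

    for retries in range(3):
        print(retries, 2 ** (retries + 4))  # -> 0 16, 1 32, 2 64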
@@ -249,7 +297,8 @@ def cloud_beat_task_generator(
         return None

     last_lock_time = time.monotonic()
-    tenant_ids: list[str] | list[None] = []
+    tenant_ids: list[str] = []
+    num_processed_tenants = 0

     try:
         tenant_ids = get_all_tenant_ids()
@@ -277,6 +326,8 @@ def cloud_beat_task_generator(
                 expires=expires,
                 ignore_result=True,
             )
+
+            num_processed_tenants += 1
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
@@ -296,6 +347,7 @@ def cloud_beat_task_generator(
     task_logger.info(
         f"cloud_beat_task_generator finished: "
         f"task={task_name} "
+        f"num_processed_tenants={num_processed_tenants} "
        f"num_tenants={len(tenant_ids)} "
        f"elapsed={time_elapsed:.2f}"
     )

@@ -19,6 +19,7 @@ from onyx.background.celery.apps.app_base import task_logger
 from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
 from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
 from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
+from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
 from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -34,7 +35,7 @@ from onyx.db.document_set import fetch_document_sets
 from onyx.db.document_set import fetch_document_sets_for_document
 from onyx.db.document_set import get_document_set_by_id
 from onyx.db.document_set import mark_document_set_as_synced
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.enums import SyncStatus
 from onyx.db.enums import SyncType
 from onyx.db.models import DocumentSet
@@ -75,7 +76,7 @@ logger = setup_logger()
     trail=False,
     bind=True,
 )
-def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
+def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
     """Runs periodically to check if any document needs syncing.
     Generates sets of tasks for Celery if syncing is needed."""

@@ -84,8 +85,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No

     time_start = time.monotonic()

-    r = get_redis_client(tenant_id=tenant_id)
-    r_replica = get_redis_replica_client(tenant_id=tenant_id)
+    r = get_redis_client()
+    r_replica = get_redis_replica_client()

     lock_beat: RedisLock = r.lock(
         OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -98,7 +99,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No

     try:
         # 1/3: KICKOFF
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             try_generate_stale_document_sync_tasks(
                 self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
             )
@@ -106,7 +107,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
         # region document set scan
         lock_beat.reacquire()
         document_set_ids: list[int] = []
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             # check if any document sets are not synced
             document_set_info = fetch_document_sets(
                 user_id=None, db_session=db_session, include_outdated=True
@@ -117,7 +118,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No

         for document_set_id in document_set_ids:
             lock_beat.reacquire()
-            with get_session_with_tenant(tenant_id) as db_session:
+            with get_session_with_current_tenant() as db_session:
                 try_generate_document_set_sync_tasks(
                     self.app, document_set_id, db_session, r, lock_beat, tenant_id
                 )
@@ -136,7 +137,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
             pass
         else:
             usergroup_ids: list[int] = []
-            with get_session_with_tenant(tenant_id) as db_session:
+            with get_session_with_current_tenant() as db_session:
                 user_groups = fetch_user_groups(
                     db_session=db_session, only_up_to_date=False
                 )
@@ -146,7 +147,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No

             for usergroup_id in usergroup_ids:
                 lock_beat.reacquire()
-                with get_session_with_tenant(tenant_id) as db_session:
+                with get_session_with_current_tenant() as db_session:
                     try_generate_user_group_sync_tasks(
                         self.app, usergroup_id, db_session, r, lock_beat, tenant_id
                     )
@@ -167,7 +168,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
             if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
                 monitor_connector_taskset(r)
             elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
-                with get_session_with_tenant(tenant_id) as db_session:
+                with get_session_with_current_tenant() as db_session:
                     monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
             elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
                 monitor_usergroup_taskset = (
@@ -177,7 +178,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
                         noop_fallback,
                     )
                 )
-                with get_session_with_tenant(tenant_id) as db_session:
+                with get_session_with_current_tenant() as db_session:
                     monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)

     except SoftTimeLimitExceeded:
@@ -207,7 +208,7 @@ def try_generate_stale_document_sync_tasks(
     db_session: Session,
     r: Redis,
     lock_beat: RedisLock,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> int | None:
     # the fence is up, do nothing

@@ -283,7 +284,7 @@ def try_generate_document_set_sync_tasks(
     db_session: Session,
     r: Redis,
     lock_beat: RedisLock,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> int | None:
     lock_beat.reacquire()

@@ -360,7 +361,7 @@ def try_generate_user_group_sync_tasks(
     db_session: Session,
     r: Redis,
     lock_beat: RedisLock,
-    tenant_id: str | None,
+    tenant_id: str,
 ) -> int | None:
     lock_beat.reacquire()

@@ -447,7 +448,7 @@ def monitor_connector_taskset(r: Redis) -> None:


 def monitor_document_set_taskset(
-    tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
+    tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
 ) -> None:
     fence_key = key_bytes.decode("utf-8")
     document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -522,13 +523,13 @@ def monitor_document_set_taskset(
     time_limit=LIGHT_TIME_LIMIT,
     max_retries=3,
 )
-def vespa_metadata_sync_task(
-    self: Task, document_id: str, tenant_id: str | None
-) -> bool:
+def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
     start = time.monotonic()

+    completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
+
     try:
-        with get_session_with_tenant(tenant_id) as db_session:
+        with get_session_with_current_tenant() as db_session:
             active_search_settings = get_active_search_settings(db_session)
             doc_index = get_default_document_index(
                 search_settings=active_search_settings.primary,
@@ -540,75 +541,103 @@ def vespa_metadata_sync_task(

             doc = get_document(document_id, db_session)
             if not doc:
-                return False
-
-            # document set sync
-            doc_sets = fetch_document_sets_for_document(document_id, db_session)
-            update_doc_sets: set[str] = set(doc_sets)
-
-            # User group sync
-            doc_access = get_access_for_document(
-                document_id=document_id, db_session=db_session
-            )
-
-            fields = VespaDocumentFields(
-                document_sets=update_doc_sets,
-                access=doc_access,
-                boost=doc.boost,
-                hidden=doc.hidden,
-            )
-
-            # update Vespa. OK if doc doesn't exist. Raises exception otherwise.
-            chunks_affected = retry_index.update_single(
-                document_id,
-                tenant_id=tenant_id,
-                chunk_count=doc.chunk_count,
-                fields=fields,
-            )
-
-            # update db last. Worst case = we crash right before this and
-            # the sync might repeat again later
-            mark_document_as_synced(document_id, db_session)
-
-            elapsed = time.monotonic() - start
-            task_logger.info(
-                f"doc={document_id} "
-                f"action=sync "
-                f"chunks={chunks_affected} "
-                f"elapsed={elapsed:.2f}"
-            )
+                elapsed = time.monotonic() - start
+                task_logger.info(
+                    f"doc={document_id} "
+                    f"action=no_operation "
+                    f"elapsed={elapsed:.2f}"
+                )
+                completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
+            else:
+                # document set sync
+                doc_sets = fetch_document_sets_for_document(document_id, db_session)
+                update_doc_sets: set[str] = set(doc_sets)
+
+                # User group sync
+                doc_access = get_access_for_document(
+                    document_id=document_id, db_session=db_session
+                )
+
+                fields = VespaDocumentFields(
+                    document_sets=update_doc_sets,
+                    access=doc_access,
+                    boost=doc.boost,
+                    hidden=doc.hidden,
+                )
+
+                # update Vespa. OK if doc doesn't exist. Raises exception otherwise.
+                chunks_affected = retry_index.update_single(
+                    document_id,
+                    tenant_id=tenant_id,
+                    chunk_count=doc.chunk_count,
+                    fields=fields,
+                )
+
+                # update db last. Worst case = we crash right before this and
+                # the sync might repeat again later
+                mark_document_as_synced(document_id, db_session)
+
+                elapsed = time.monotonic() - start
+                task_logger.info(
+                    f"doc={document_id} "
+                    f"action=sync "
+                    f"chunks={chunks_affected} "
+                    f"elapsed={elapsed:.2f}"
+                )
+                completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
     except SoftTimeLimitExceeded:
         task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
-        return False
+        completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
     except Exception as ex:
         e: Exception | None = None
-        if isinstance(ex, RetryError):
-            task_logger.warning(
-                f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
-            )
-
-            # only set the inner exception if it is of type Exception
-            e_temp = ex.last_attempt.exception()
-            if isinstance(e_temp, Exception):
-                e = e_temp
-        else:
-            e = ex
-
-        if isinstance(e, httpx.HTTPStatusError):
-            if e.response.status_code == HTTPStatus.BAD_REQUEST:
-                task_logger.exception(
-                    f"Non-retryable HTTPStatusError: "
-                    f"doc={document_id} "
-                    f"status={e.response.status_code}"
-                )
-                return False
-
-        task_logger.exception(
-            f"Unexpected exception during vespa metadata sync: doc={document_id}"
-        )
-
-        # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
-        countdown = 2 ** (self.request.retries + 4)
-        self.retry(exc=e, countdown=countdown)
+        while True:
+            if isinstance(ex, RetryError):
+                task_logger.warning(
+                    f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
+                )
+
+                # only set the inner exception if it is of type Exception
+                e_temp = ex.last_attempt.exception()
+                if isinstance(e_temp, Exception):
+                    e = e_temp
+            else:
+                e = ex
+
+            if isinstance(e, httpx.HTTPStatusError):
+                if e.response.status_code == HTTPStatus.BAD_REQUEST:
+                    task_logger.exception(
+                        f"Non-retryable HTTPStatusError: "
+                        f"doc={document_id} "
+                        f"status={e.response.status_code}"
+                    )
+                    completion_status = (
+                        OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
+                    )
+                    break
+
+            task_logger.exception(
+                f"vespa_metadata_sync_task exceptioned: doc={document_id}"
+            )
+
+            completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
+            if (
+                self.max_retries is not None
+                and self.request.retries >= self.max_retries
+            ):
+                completion_status = (
+                    OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
+                )
+
+            # Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
+            countdown = 2 ** (self.request.retries + 4)
+            self.retry(exc=e, countdown=countdown)  # this will raise a celery exception
+            break  # we won't hit this, but it looks weird not to have it
+    finally:
+        task_logger.info(
+            f"vespa_metadata_sync_task completed: status={completion_status.value} doc={document_id}"
+        )
+
+        if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
+            return False

     return True

@@ -1,5 +1,7 @@
+from sqlalchemy.exc import IntegrityError
+
 from onyx.db.background_error import create_background_error
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant


 def emit_background_error(
@@ -9,5 +11,10 @@ def emit_background_error(
     """Currently just saves a row in the background_errors table.

     In the future, could create notifications based on the severity."""
-    with get_session_with_tenant() as db_session:
-        create_background_error(db_session, message, cc_pair_id)
+    with get_session_with_current_tenant() as db_session:
+        try:
+            create_background_error(db_session, message, cc_pair_id)
+        except IntegrityError as e:
+            # Log an error if the cc_pair_id was deleted or any other exception occurs
+            error_message = f"Failed to create background error: {str(e)}. Original message: {message}"
+            create_background_error(db_session, error_message, None)

@@ -16,7 +16,10 @@ from typing import Optional

 from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
 from onyx.db.engine import SqlEngine
-from onyx.utils.logger import setup_logger
+from onyx.setup import setup_logger
+from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
+from shared_configs.configs import TENANT_ID_PREFIX
+from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

 logger = setup_logger()

@@ -54,6 +57,15 @@ def _initializer(
         kwargs = {}

     logger.info("Initializing spawned worker child process.")
+    # 1. Get tenant_id from args or fallback to default
+    tenant_id = POSTGRES_DEFAULT_SCHEMA
+    for arg in reversed(args):
+        if isinstance(arg, str) and arg.startswith(TENANT_ID_PREFIX):
+            tenant_id = arg
+            break
+
+    # 2. Set the tenant context before running anything
+    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

     # Reset the engine in the child process
     SqlEngine.reset_engine()
@@ -81,6 +93,8 @@ def _initializer(
         queue.put(error_msg)  # Send the exception to the parent process

         sys.exit(255)  # use 255 to indicate a generic exception
+    finally:
+        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


 def _run_in_process(

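`_initializer` now brackets the child's work with `ContextVar.set(...)` and `reset(token)`, so the tenant neither leaks across reused worker processes nor escapes the initializer on error. The token pattern in isolation:

    from contextvars import ContextVar

    tenant_cv: ContextVar[str] = ContextVar("tenant", default="public")

    def run_with_tenant(tenant_id: str) -> None:
        token = tenant_cv.set(tenant_id)  # remember the previous value
        try:
            ...  # everything here sees tenant_cv.get() == tenant_id
        finally:
            tenant_cv.reset(token)  # restore even on exceptions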
@@ -15,11 +15,13 @@ from onyx.background.indexing.memory_tracer import MemoryTracer
 from onyx.configs.app_configs import INDEX_BATCH_SIZE
 from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
 from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
+from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
 from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
 from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
 from onyx.configs.constants import DocumentSource
 from onyx.configs.constants import MilestoneRecordType
 from onyx.connectors.connector_runner import ConnectorRunner
+from onyx.connectors.exceptions import ConnectorValidationError
 from onyx.connectors.factory import instantiate_connector
 from onyx.connectors.models import ConnectorCheckpoint
 from onyx.connectors.models import ConnectorFailure
@@ -28,7 +30,7 @@ from onyx.connectors.models import IndexAttemptMetadata
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
 from onyx.db.connector_credential_pair import get_last_successful_attempt_time
 from onyx.db.connector_credential_pair import update_connector_credential_pair
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
 from onyx.db.enums import ConnectorCredentialPairStatus
 from onyx.db.index_attempt import create_index_attempt_error
 from onyx.db.index_attempt import get_index_attempt
@@ -53,6 +55,7 @@ from onyx.utils.logger import setup_logger
 from onyx.utils.logger import TaskAttemptSingleton
 from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.variable_functionality import global_version
+from shared_configs.configs import MULTI_TENANT

 logger = setup_logger()

@@ -65,7 +68,6 @@ def _get_connector_runner(
     batch_size: int,
     start_time: datetime,
     end_time: datetime,
-    tenant_id: str | None,
     leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
 ) -> ConnectorRunner:
     """
@@ -84,8 +86,12 @@ def _get_connector_runner(
             input_type=task,
             connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
             credential=attempt.connector_credential_pair.credential,
-            tenant_id=tenant_id,
         )

+        # validate the connector settings
+        if not INTEGRATION_TESTS_MODE:
+            runnable_connector.validate_connector_settings()
+
     except Exception as e:
         logger.exception(f"Unable to instantiate connector due to {e}")

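Validating connector settings at instantiation time (skipped under `INTEGRATION_TESTS_MODE`) is what lets the new `except ConnectorValidationError` paths earlier in this diff report misconfiguration as a distinct terminal status rather than a generic crash. The gate reduces to:

    # fail fast on bad credentials/config, but let test harnesses bypass the check
    if not INTEGRATION_TESTS_MODE:
        runnable_connector.validate_connector_settings()  # raises ConnectorValidationError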
@@ -234,7 +240,7 @@ def _check_failure_threshold(
 def _run_indexing(
     db_session: Session,
     index_attempt_id: int,
-    tenant_id: str | None,
+    tenant_id: str,
     callback: IndexingHeartbeatInterface | None = None,
 ) -> None:
     """
@@ -244,7 +250,7 @@ def _run_indexing(
|
||||
"""
|
||||
    start_time = time.monotonic()  # just used for logging

    with get_session_with_tenant(tenant_id) as db_session_temp:
    with get_session_with_current_tenant() as db_session_temp:
        index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
        if not index_attempt_start:
            raise ValueError(
@@ -370,7 +376,7 @@ def _run_indexing(
    document_count = 0
    chunk_count = 0
    try:
        with get_session_with_tenant(tenant_id) as db_session_temp:
        with get_session_with_current_tenant() as db_session_temp:
            index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
            if not index_attempt:
                raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
@@ -381,7 +387,6 @@ def _run_indexing(
                batch_size=INDEX_BATCH_SIZE,
                start_time=window_start,
                end_time=window_end,
                tenant_id=tenant_id,
            )

            # don't use a checkpoint if we're explicitly indexing from
@@ -430,7 +435,7 @@ def _run_indexing(
                    raise ConnectorStopSignal("Connector stop signal detected")

                # TODO: should we move this into the above callback instead?
                with get_session_with_tenant(tenant_id) as db_session_temp:
                with get_session_with_current_tenant() as db_session_temp:
                    # will exception if the connector/index attempt is marked as paused/failed
                    _check_connector_and_attempt_status(
                        db_session_temp, ctx, index_attempt_id
@@ -439,7 +444,7 @@ def _run_indexing(
                # save record of any failures at the connector level
                if failure is not None:
                    total_failures += 1
                    with get_session_with_tenant(tenant_id) as db_session_temp:
                    with get_session_with_current_tenant() as db_session_temp:
                        create_index_attempt_error(
                            index_attempt_id,
                            ctx.cc_pair_id,
@@ -503,7 +508,7 @@ def _run_indexing(
                    if document.id not in failed_document_ids
                ]
                for document_id in successful_document_ids:
                    with get_session_with_tenant(tenant_id) as db_session_temp:
                    with get_session_with_current_tenant() as db_session_temp:
                        if document_id in doc_id_to_unresolved_errors:
                            logger.info(
                                f"Resolving IndexAttemptError for document '{document_id}'"
@@ -516,7 +521,7 @@ def _run_indexing(
                # add brand new failures
                if index_pipeline_result.failures:
                    total_failures += len(index_pipeline_result.failures)
                    with get_session_with_tenant(tenant_id) as db_session_temp:
                    with get_session_with_current_tenant() as db_session_temp:
                        for failure in index_pipeline_result.failures:
                            create_index_attempt_error(
                                index_attempt_id,
@@ -533,7 +538,7 @@ def _run_indexing(
                    )

                # This new value is updated every batch, so UI can refresh per batch update
                with get_session_with_tenant(tenant_id) as db_session_temp:
                with get_session_with_current_tenant() as db_session_temp:
                    # NOTE: Postgres uses the start of the transactions when computing `NOW()`
                    # so we need either to commit() or to use a new session
                    update_docs_indexed(
@@ -555,7 +560,7 @@ def _run_indexing(
                check_checkpoint_size(checkpoint)

                # save latest checkpoint
                with get_session_with_tenant(tenant_id) as db_session_temp:
                with get_session_with_current_tenant() as db_session_temp:
                    save_checkpoint(
                        db_session=db_session_temp,
                        index_attempt_id=index_attempt_id,
@@ -567,9 +572,29 @@ def _run_indexing(
            "Connector run exceptioned after elapsed time: "
            f"{time.monotonic() - start_time} seconds"
        )
        if isinstance(e, ConnectorValidationError):
            # On validation errors during indexing, we want to cancel the indexing attempt
            # and mark the CCPair as invalid. This prevents the connector from being
            # used in the future until the credentials are updated.
            with get_session_with_current_tenant() as db_session_temp:
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session_temp,
                    reason=str(e),
                )

        if isinstance(e, ConnectorStopSignal):
            with get_session_with_tenant(tenant_id) as db_session_temp:
                if ctx.is_primary:
                    update_connector_credential_pair(
                        db_session=db_session_temp,
                        connector_id=ctx.connector_id,
                        credential_id=ctx.credential_id,
                        status=ConnectorCredentialPairStatus.INVALID,
                    )
            memory_tracer.stop()
            raise e

        elif isinstance(e, ConnectorStopSignal):
            with get_session_with_current_tenant() as db_session_temp:
                mark_attempt_canceled(
                    index_attempt_id,
                    db_session_temp,
@@ -587,7 +612,7 @@ def _run_indexing(
            memory_tracer.stop()
            raise e
        else:
            with get_session_with_tenant(tenant_id) as db_session_temp:
            with get_session_with_current_tenant() as db_session_temp:
                mark_attempt_failed(
                    index_attempt_id,
                    db_session_temp,
@@ -609,7 +634,7 @@ def _run_indexing(
    memory_tracer.stop()

    elapsed_time = time.monotonic() - start_time
    with get_session_with_tenant(tenant_id) as db_session_temp:
    with get_session_with_current_tenant() as db_session_temp:
        # resolve entity-based errors
        for error in entity_based_unresolved_errors:
            logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
@@ -654,7 +679,7 @@ def _run_indexing(

def run_indexing_entrypoint(
    index_attempt_id: int,
    tenant_id: str | None,
    tenant_id: str,
    connector_credential_pair_id: int,
    is_ee: bool = False,
    callback: IndexingHeartbeatInterface | None = None,
@@ -669,12 +694,12 @@ def run_indexing_entrypoint(
    TaskAttemptSingleton.set_cc_and_index_id(
        index_attempt_id, connector_credential_pair_id
    )
    with get_session_with_tenant(tenant_id) as db_session:
    with get_session_with_current_tenant() as db_session:
        # TODO: remove long running session entirely
        attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)

    tenant_str = ""
    if tenant_id is not None:
    if MULTI_TENANT:
        tenant_str = f" for tenant {tenant_id}"

    connector_name = attempt.connector_credential_pair.connector.name
@@ -690,7 +715,7 @@ def run_indexing_entrypoint(
        f"credentials='{credential_id}'"
    )

    with get_session_with_tenant(tenant_id) as db_session:
    with get_session_with_current_tenant() as db_session:
        _run_indexing(db_session, index_attempt_id, tenant_id, callback)

    logger.info(

@@ -27,8 +27,10 @@ from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -80,6 +82,26 @@ class Answer:
            and not skip_explicit_tool_calling
        )

        rerank_settings = search_request.rerank_settings

        using_cloud_reranking = (
            rerank_settings is not None
            and rerank_settings.rerank_provider_type is not None
        )
        allow_agent_reranking = gpu_status_request() or using_cloud_reranking

        # TODO: this is a hack to force the query to be used for the search tool
        # this should be removed once we fully unify graph inputs (i.e.
        # remove SearchQuery entirely)
        if (
            force_use_tool.force_use
            and search_tool
            and force_use_tool.args
            and force_use_tool.tool_name == search_tool.name
            and QUERY_FIELD in force_use_tool.args
        ):
            search_request.query = force_use_tool.args[QUERY_FIELD]

        self.graph_inputs = GraphInputs(
            search_request=search_request,
            prompt_builder=prompt_builder,
@@ -94,7 +116,6 @@ class Answer:
            force_use_tool=force_use_tool,
            using_tool_calling_llm=using_tool_calling_llm,
        )
        assert db_session, "db_session must be provided for agentic persistence"
        self.graph_persistence = GraphPersistence(
            db_session=db_session,
            chat_session_id=chat_session_id,
@@ -104,6 +125,7 @@ class Answer:
            use_agentic_search=use_agentic_search,
            skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
            allow_refinement=True,
            allow_agent_reranking=allow_agent_reranking,
        )
        self.graph_config = GraphConfig(
            inputs=self.graph_inputs,

@@ -190,7 +190,8 @@ def create_chat_chain(
            and previous_message.message_type == MessageType.ASSISTANT
            and mainline_messages
        ):
            mainline_messages[-1] = current_message
            if current_message.refined_answer_improvement:
                mainline_messages[-1] = current_message
        else:
            mainline_messages.append(current_message)


@@ -142,6 +142,15 @@ class MessageResponseIDInfo(BaseModel):
    reserved_assistant_message_id: int


class AgentMessageIDInfo(BaseModel):
    level: int
    message_id: int


class AgenticMessageResponseIDInfo(BaseModel):
    agentic_message_ids: list[AgentMessageIDInfo]

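The two models above define the new ID packet streamed back for agentic answers, one entry per refinement level. Assuming Pydantic v2 serialization (`model_dump`), a sample payload looks like this (values are illustrative):

ids = AgenticMessageResponseIDInfo(
    agentic_message_ids=[AgentMessageIDInfo(level=1, message_id=42)]
)
print(ids.model_dump())
# {'agentic_message_ids': [{'level': 1, 'message_id': 42}]}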

class StreamingError(BaseModel):
    error: str
    stack_trace: str | None = None

@@ -7,10 +7,12 @@ from typing import cast

from sqlalchemy.orm import Session

from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
from onyx.chat.models import AgenticMessageResponseIDInfo
from onyx.chat.models import AgentMessageIDInfo
from onyx.chat.models import AgentSearchPacket
from onyx.chat.models import AllCitations
from onyx.chat.models import AnswerPostInfo
@@ -143,9 +145,10 @@ from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"


def _translate_citations(
@@ -307,6 +310,7 @@ ChatPacket = (
    | CustomToolResponse
    | MessageSpecificCitations
    | MessageResponseIDInfo
    | AgenticMessageResponseIDInfo
    | StreamStopInfo
    | AgentSearchPacket
)
@@ -342,7 +346,7 @@ def stream_chat_message_objects(
    3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
    4. [always] Details on the final AI response message that is created
    """
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    tenant_id = get_current_tenant_id()
    use_existing_user_message = new_msg_req.use_existing_user_message
    existing_assistant_message_id = new_msg_req.existing_assistant_message_id

@@ -631,6 +635,7 @@ def stream_chat_message_objects(
            db_session=db_session,
            commit=False,
            reserved_message_id=reserved_message_id,
            is_agentic=new_msg_req.use_agentic_search,
        )

        prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
@@ -742,14 +747,13 @@ def stream_chat_message_objects(
                files=latest_query_files,
                single_message_history=single_message_history,
            ),
            system_message=default_build_system_message(prompt_config),
            system_message=default_build_system_message(prompt_config, llm.config),
            message_history=message_history,
            llm_config=llm.config,
            raw_user_query=final_msg.message,
            raw_user_uploaded_files=latest_query_files or [],
            single_message_history=single_message_history,
        )
        prompt_builder.update_system_prompt(default_build_system_message(prompt_config))

        # LLM prompt building, response capturing, etc.
        answer = Answer(
@@ -865,7 +869,6 @@ def stream_chat_message_objects(
                        for img in img_generation_response
                        if img.image_data
                    ],
                    tenant_id=tenant_id,
                )
                info.ai_message_files.extend(
                    [
@@ -1015,7 +1018,7 @@ def stream_chat_message_objects(
                if info.message_specific_citations
                else None
            ),
            error=None,
            error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
            tool_call=(
                ToolCall(
                    tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
@@ -1033,6 +1036,7 @@ def stream_chat_message_objects(
        next_level = 1
        prev_message = gen_ai_response_message
        agent_answers = answer.llm_answer_by_level()
        agentic_message_ids = []
        while next_level in agent_answers:
            next_answer = agent_answers[next_level]
            info = info_by_subq[
@@ -1053,7 +1057,12 @@ def stream_chat_message_objects(
                citations=info.message_specific_citations.citation_map
                if info.message_specific_citations
                else None,
                error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
                refined_answer_improvement=refined_answer_improvement,
                is_agentic=True,
            )
            agentic_message_ids.append(
                AgentMessageIDInfo(level=next_level, message_id=next_answer_message.id)
            )
            next_level += 1
            prev_message = next_answer_message
@@ -1061,11 +1070,9 @@ def stream_chat_message_objects(
        logger.debug("Committing messages")
        db_session.commit()  # actually save user / assistant message

        msg_detail_response = translate_db_message_to_chat_message_detail(
            gen_ai_response_message
        )
        yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)

        yield msg_detail_response
        yield translate_db_message_to_chat_message_detail(gen_ai_response_message)
    except Exception as e:
        error_msg = str(e)
        logger.exception(error_msg)

@@ -12,6 +12,7 @@ from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_toke
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
@@ -19,6 +20,7 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
@@ -31,8 +33,16 @@ from onyx.tools.tool import Tool

def default_build_system_message(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
) -> SystemMessage | None:
    system_prompt = prompt_config.system_prompt.strip()
    # See https://simonwillison.net/tags/markdown/ for context on this temporary fix
    # for o-series markdown generation
    if (
        llm_config.model_provider == OPENAI_PROVIDER_NAME
        and llm_config.model_name.startswith("o")
    ):
        system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
    tag_handled_prompt = handle_onyx_date_awareness(
        system_prompt,
        prompt_config,
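The new `llm_config` parameter exists so the builder can special-case OpenAI o-series models, which tend to drop markdown formatting unless explicitly instructed; the fix is a simple prompt prefix. A standalone sketch of the same check, with stand-in constants (the real values come from the imports above):

# stand-ins; the diff imports these from onyx.prompts.chat_prompts and
# onyx.llm.llm_provider_options
CODE_BLOCK_MARKDOWN = "Format your responses in markdown.\n\n"
OPENAI_PROVIDER_NAME = "openai"


def build_system_prompt(system_prompt: str, provider: str, model_name: str) -> str:
    # o-series models (o1, o3, ...) need the explicit markdown instruction
    if provider == OPENAI_PROVIDER_NAME and model_name.startswith("o"):
        return CODE_BLOCK_MARKDOWN + system_prompt
    return system_prompt


print(build_system_prompt("You are helpful.", "openai", "o1"))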
@@ -110,21 +120,8 @@ class AnswerPromptBuilder:
            ),
        )

        self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
            (
                system_message,
                check_message_tokens(system_message, self.llm_tokenizer_encode_func),
            )
            if system_message
            else None
        )
        self.user_message_and_token_cnt = (
            user_message,
            check_message_tokens(
                user_message,
                self.llm_tokenizer_encode_func,
            ),
        )
        self.update_system_prompt(system_message)
        self.update_user_prompt(user_message)

        self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []


@@ -31,22 +31,9 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000

AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30  # in seconds

AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 1  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 3  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25  # in seconds

AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = 8  # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8  # in seconds

AGENT_ANSWER_GENERATION_BY_FAST_LLM = (
    os.environ.get("AGENT_ANSWER_GENERATION_BY_FAST_LLM", "").lower() == "true"
)

AGENT_RETRIEVAL_STATS = (
    not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
@@ -178,80 +165,172 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
)  # 2000


AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 10  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION
)

AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 30  # in seconds
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION")
    or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION
)


AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
) # 3
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 3  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
)

AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
) # 30
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 5  # in seconds
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
)


AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 5  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 30  # in seconds
AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
)


AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
) # 12
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 4  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 5  # in seconds
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION
)


AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30  # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)


AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 25  # in seconds
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
)


AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30  # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
)


AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
) # 6
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 4  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
)

AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 8  # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
    os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK")
    or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK
)


AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
) # 1
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 4  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 8  # in seconds
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION
)


AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
) # 4
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 2  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 3  # in seconds
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
)


AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 4  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 5  # in seconds
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
)


AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = int(
    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION")
    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 4  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
)

AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 8  # in seconds
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
    os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS")
    or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS
)


AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 4  # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
)

AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 8  # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
)

GRAPH_VERSION_NAME: str = "a"

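Every pair in this block follows the same override pattern: a hardcoded default plus an optional environment variable, combined with `or` rather than a `get(..., default)` so that an unset or empty-string variable both fall back before `int()` runs. A reduced sketch of the pattern with a hypothetical setting name:

import os

AGENT_DEFAULT_TIMEOUT_EXAMPLE = 8  # in seconds; illustrative name only

# os.environ.get(...) yields None when unset and "" when set-but-empty;
# `or` sends both cases to the default before int() parses the value
AGENT_TIMEOUT_EXAMPLE = int(
    os.environ.get("AGENT_TIMEOUT_EXAMPLE") or AGENT_DEFAULT_TIMEOUT_EXAMPLE
)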
@@ -6,6 +6,7 @@ from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy

#####
@@ -29,6 +30,9 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
)  # 1 day
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"

ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
    (os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
)

#####
# Web Configs
@@ -626,6 +630,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")

DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"

INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"

MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")

TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"

@@ -98,9 +98,18 @@ CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120

CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120

# needs to be long enough to cover the maximum time it takes to download an object

# hard timeout applied by the watchdog to the indexing connector run
# to handle hung connectors
CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60  # 3 hours (in seconds)

# soft timeout for the lock taken by the indexing connector run
# allows the lock to eventually expire if the managing code around it dies
# if we can get callbacks as object bytes download, we could lower this a lot.
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60  # 60 min
# CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 15 minutes
# hard termination should always fire first if the connector is hung
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900


# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60  # 5 min
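The rewritten lock timeout encodes an ordering invariant: the watchdog's hard kill must fire while the soft lock is still held, so the lock expiry trails the hard timeout by a fixed 15-minute buffer. A quick check with the values from the diff:

CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60  # hard kill at 3 hours
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900

# soft lock expiry always trails hard termination, so the lock cannot
# vanish out from under a connector that is about to be killed
assert CELERY_INDEXING_LOCK_TIMEOUT - CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT == 900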
@@ -204,6 +213,12 @@ class AuthType(str, Enum):
    CLOUD = "cloud"


class QueryHistoryType(str, Enum):
    DISABLED = "disabled"
    ANONYMIZED = "anonymized"
    NORMAL = "normal"


# Special characters for password validation
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"

@@ -333,6 +348,9 @@ class OnyxRedisSignals:
    BLOCK_PRUNING = "signal:block_pruning"
    BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
    BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
    BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
        "signal:block_validate_connector_deletion_fences"
    )


class OnyxRedisConstants:

@@ -7,11 +7,18 @@ from typing import Optional

import boto3  # type: ignore
from botocore.client import Config  # type: ignore
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from botocore.exceptions import PartialCredentialsError
from mypy_boto3_s3 import S3Client  # type: ignore

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import BlobType
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -240,6 +247,73 @@ class BlobStorageConnector(LoadConnector, PollConnector):

        return None

    def validate_connector_settings(self) -> None:
        if self.s3_client is None:
            raise ConnectorMissingCredentialError(
                "Blob storage credentials not loaded."
            )

        if not self.bucket_name:
            raise ConnectorValidationError(
                "No bucket name was provided in connector settings."
            )

        try:
            # We only fetch one object/page as a light-weight validation step.
            # This ensures we trigger typical S3 permission checks (ListObjectsV2, etc.).
            self.s3_client.list_objects_v2(
                Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
            )

        except NoCredentialsError:
            raise ConnectorMissingCredentialError(
                "No valid blob storage credentials found or provided to boto3."
            )
        except PartialCredentialsError:
            raise ConnectorMissingCredentialError(
                "Partial or incomplete blob storage credentials provided to boto3."
            )
        except ClientError as e:
            error_code = e.response["Error"].get("Code", "")
            status_code = e.response["ResponseMetadata"].get("HTTPStatusCode")

            # Most common S3 error cases
            if error_code in [
                "AccessDenied",
                "InvalidAccessKeyId",
                "SignatureDoesNotMatch",
            ]:
                if status_code == 403 or error_code == "AccessDenied":
                    raise InsufficientPermissionsError(
                        f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
                        "Please check your bucket policy and/or IAM policy."
                    )
                if status_code == 401 or error_code == "SignatureDoesNotMatch":
                    raise CredentialExpiredError(
                        "Provided blob storage credentials appear invalid or expired."
                    )

                raise CredentialExpiredError(
                    f"Credential issue encountered ({error_code})."
                )

            if error_code == "NoSuchBucket" or status_code == 404:
                raise ConnectorValidationError(
                    f"Bucket '{self.bucket_name}' does not exist or cannot be found."
                )

            raise ConnectorValidationError(
                f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
            )

        except Exception as e:
            # Catch-all for anything not captured by the above
            # Since we are unsure of the error and it may not disable the connector,
            # raise an unexpected error (does not disable connector)
            raise UnexpectedError(
                f"Unexpected error during blob storage settings validation: {e}"
            )

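With `validate_connector_settings` in place, a caller can map the exception hierarchy (added in `backend/onyx/connectors/exceptions.py` later in this diff) to distinct outcomes. A hedged sketch of such a caller, not code from the PR; the specific subclasses must be caught before `ConnectorValidationError`:

from onyx.connectors.exceptions import (
    ConnectorValidationError,
    CredentialExpiredError,
    InsufficientPermissionsError,
    UnexpectedError,
)


def check_connector(connector) -> str:
    # `connector` is anything implementing validate_connector_settings(),
    # e.g. the BlobStorageConnector above (hypothetical helper function)
    try:
        connector.validate_connector_settings()
    except CredentialExpiredError:
        return "reauth"           # ask the user for fresh credentials
    except InsufficientPermissionsError:
        return "fix_permissions"  # credentials work but lack scopes
    except ConnectorValidationError:
        return "disable"          # settings are wrong; pause the connector
    except UnexpectedError:
        return "retry"            # validation itself failed; don't disable
    return "ok"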
if __name__ == "__main__":
    credentials_dict = {

@@ -5,6 +5,8 @@ import requests

class BookStackClientRequestFailedError(ConnectionError):
    def __init__(self, status: int, error: str) -> None:
        self.status_code = status
        self.error = error
        super().__init__(
            "BookStack Client request failed with status {status}: {error}".format(
                status=status, error=error

@@ -7,7 +7,11 @@ from typing import Any
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.bookstack.client import BookStackApiClient
from onyx.connectors.bookstack.client import BookStackClientRequestFailedError
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -214,3 +218,39 @@ class BookstackConnector(LoadConnector, PollConnector):
                break
            else:
                time.sleep(0.2)

    def validate_connector_settings(self) -> None:
        """
        Validate that the BookStack credentials and connector settings are correct.
        Specifically checks that we can make an authenticated request to BookStack.
        """
        if not self.bookstack_client:
            raise ConnectorMissingCredentialError(
                "BookStack credentials have not been loaded."
            )

        try:
            # Attempt to fetch a small batch of books (arbitrary endpoint) to verify credentials
            _ = self.bookstack_client.get(
                "/books", params={"count": "1", "offset": "0"}
            )

        except BookStackClientRequestFailedError as e:
            # Check for HTTP status codes
            if e.status_code == 401:
                raise CredentialExpiredError(
                    "Your BookStack credentials appear to be invalid or expired (HTTP 401)."
                ) from e
            elif e.status_code == 403:
                raise InsufficientPermissionsError(
                    "The configured BookStack token does not have sufficient permissions (HTTP 403)."
                ) from e
            else:
                raise ConnectorValidationError(
                    f"Unexpected BookStack error (status={e.status_code}): {e}"
                ) from e

        except Exception as exc:
            raise ConnectorValidationError(
                f"Unexpected error while validating BookStack connector settings: {exc}"
            ) from exc

@@ -4,6 +4,8 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote

from requests.exceptions import HTTPError

from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -16,6 +18,10 @@ from onyx.connectors.confluence.utils import build_confluence_document_id
from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -397,3 +403,33 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
                callback.progress("retrieve_all_slim_documents", 1)

        yield doc_metadata_list

    def validate_connector_settings(self) -> None:
        if self._confluence_client is None:
            raise ConnectorMissingCredentialError("Confluence credentials not loaded.")

        try:
            spaces = self._confluence_client.get_all_spaces(limit=1)
        except HTTPError as e:
            status_code = e.response.status_code if e.response else None
            if status_code == 401:
                raise CredentialExpiredError(
                    "Invalid or expired Confluence credentials (HTTP 401)."
                )
            elif status_code == 403:
                raise InsufficientPermissionsError(
                    "Insufficient permissions to access Confluence resources (HTTP 403)."
                )
            raise UnexpectedError(
                f"Unexpected Confluence error (status={status_code}): {e}"
            )
        except Exception as e:
            raise UnexpectedError(
                f"Unexpected error while validating Confluence settings: {e}"
            )

        if not spaces or not spaces.get("results"):
            raise ConnectorValidationError(
                "No Confluence spaces found. Either your credentials lack permissions, or "
                "there truly are no spaces in this Confluence instance."
            )

@@ -8,8 +8,12 @@ from typing import TypeVar
from urllib.parse import quote

from atlassian import Confluence  # type:ignore
from pydantic import BaseModel
from requests import HTTPError

from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.utils.logger import setup_logger

logger = setup_logger()
@@ -29,6 +33,16 @@ class ConfluenceRateLimitError(Exception):
    pass


class ConfluenceUser(BaseModel):
    user_id: str  # accountId in Cloud, userKey in Server
    username: str | None  # Confluence Cloud doesn't give usernames
    display_name: str
    # Confluence Data Center doesn't give email back by default,
    # have to fetch it with a different endpoint
    email: str | None
    type: str


def _handle_http_error(e: HTTPError, attempt: int) -> int:
    MIN_DELAY = 2
    MAX_DELAY = 60
@@ -149,7 +163,7 @@ class OnyxConfluence(Confluence):
        )

    def _paginate_url(
        self, url_suffix: str, limit: int | None = None
        self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
    ) -> Iterator[dict[str, Any]]:
        """
        This will paginate through the top level query.
@@ -224,9 +238,41 @@ class OnyxConfluence(Confluence):
                    raise e

            # yield the results individually
            yield from next_response.get("results", [])
            results = cast(list[dict[str, Any]], next_response.get("results", []))
            yield from results

            url_suffix = next_response.get("_links", {}).get("next")
            old_url_suffix = url_suffix
            url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))

            # make sure we don't update the start by more than the amount
            # of results we were able to retrieve. The Confluence API has a
            # weird behavior where if you pass in a limit that is too large for
            # the configured server, it will artificially limit the amount of
            # results returned BUT will not apply this to the start parameter.
            # This will cause us to miss results.
            if url_suffix and "start" in url_suffix:
                new_start = get_start_param_from_url(url_suffix)
                previous_start = get_start_param_from_url(old_url_suffix)
                if new_start - previous_start > len(results):
                    logger.warning(
                        f"Start was updated by more than the amount of results "
                        f"retrieved. This is a bug with Confluence. Start: {new_start}, "
                        f"Previous Start: {previous_start}, Len Results: {len(results)}."
                    )

                    # Update the url_suffix to use the adjusted start
                    adjusted_start = previous_start + len(results)
                    url_suffix = update_param_in_path(
                        url_suffix, "start", str(adjusted_start)
                    )

            # some APIs don't properly paginate, so we need to manually update the `start` param
            if auto_paginate and len(results) > 0:
                previous_start = get_start_param_from_url(old_url_suffix)
                updated_start = previous_start + len(results)
                url_suffix = update_param_in_path(
                    old_url_suffix, "start", str(updated_start)
                )

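To see the correction concretely: suppose a page is requested with limit=200 but the server caps the response at 50 results, while the next link still advances start by 200. A worked example using the helpers added in this diff (URLs and counts are illustrative):

old_url_suffix = "/rest/api/search/user?cql=type=user&start=0&limit=200"
url_suffix = "/rest/api/search/user?cql=type=user&start=200&limit=200"
results = [{"title": f"user {i}"} for i in range(50)]  # server returned only 50

new_start = get_start_param_from_url(url_suffix)           # 200
previous_start = get_start_param_from_url(old_url_suffix)  # 0
if new_start - previous_start > len(results):              # 200 > 50: rows skipped
    # rewind start to previous_start + len(results) == 50 so nothing is missed
    url_suffix = update_param_in_path(url_suffix, "start", "50")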
    def paginated_cql_retrieval(
        self,
@@ -275,21 +321,97 @@ class OnyxConfluence(Confluence):
        self,
        expand: str | None = None,
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
    ) -> Iterator[ConfluenceUser]:
        """
        The search/user endpoint can be used to fetch users.
        It's a separate endpoint from the content/search endpoint, used only for users.
        Otherwise it's very similar to the content/search endpoint.
        """
        cql = "type=user"
        url = "rest/api/search/user" if self.cloud else "rest/api/search"
        expand_string = f"&expand={expand}" if expand else ""
        url += f"?cql={cql}{expand_string}"
        yield from self._paginate_url(url, limit)
        if self.cloud:
            cql = "type=user"
            url = "rest/api/search/user"
            expand_string = f"&expand={expand}" if expand else ""
            url += f"?cql={cql}{expand_string}"
            # endpoint doesn't properly paginate, so we need to manually update the `start` param
            # thus the auto_paginate flag
            for user_result in self._paginate_url(url, limit, auto_paginate=True):
                # Example response:
                # {
                #     'user': {
                #         'type': 'known',
                #         'accountId': '712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
                #         'accountType': 'atlassian',
                #         'email': 'chris@danswer.ai',
                #         'publicName': 'Chris Weaver',
                #         'profilePicture': {
                #             'path': '/wiki/aa-avatar/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
                #             'width': 48,
                #             'height': 48,
                #             'isDefault': False
                #         },
                #         'displayName': 'Chris Weaver',
                #         'isExternalCollaborator': False,
                #         '_expandable': {
                #             'operations': '',
                #             'personalSpace': ''
                #         },
                #         '_links': {
                #             'self': 'https://danswerai.atlassian.net/wiki/rest/api/user?accountId=712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d'
                #         }
                #     },
                #     'title': 'Chris Weaver',
                #     'excerpt': '',
                #     'url': '/people/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
                #     'breadcrumbs': [],
                #     'entityType': 'user',
                #     'iconCssClass': 'aui-icon content-type-profile',
                #     'lastModified': '2025-02-18T04:08:03.579Z',
                #     'score': 0.0
                # }
                user = user_result["user"]
                yield ConfluenceUser(
                    user_id=user["accountId"],
                    username=None,
                    display_name=user["displayName"],
                    email=user.get("email"),
                    type=user["accountType"],
                )
        else:
            # https://developer.atlassian.com/server/confluence/rest/v900/api-group-user/#api-rest-api-user-list-get
            # ^ is only available on data center deployments
            # Example response:
            # [
            #     {
            #         'type': 'known',
            #         'username': 'admin',
            #         'userKey': '40281082950c5fe901950c61c55d0000',
            #         'profilePicture': {
            #             'path': '/images/icons/profilepics/default.svg',
            #             'width': 48,
            #             'height': 48,
            #             'isDefault': True
            #         },
            #         'displayName': 'Admin Test',
            #         '_links': {
            #             'self': 'http://localhost:8090/rest/api/user?key=40281082950c5fe901950c61c55d0000'
            #         },
            #         '_expandable': {
            #             'status': ''
            #         }
            #     }
            # ]
            for user in self._paginate_url("rest/api/user/list", limit):
                yield ConfluenceUser(
                    user_id=user["userKey"],
                    username=user["username"],
                    display_name=user["displayName"],
                    email=None,
                    type=user.get("type", "user"),
                )

    def paginated_groups_by_user_retrieval(
        self,
        user: dict[str, Any],
        user_id: str,  # accountId in Cloud, userKey in Server
        limit: int | None = None,
    ) -> Iterator[dict[str, Any]]:
        """
@@ -297,7 +419,7 @@ class OnyxConfluence(Confluence):
        It's a confluence specific endpoint that can be used to fetch groups.
        """
        user_field = "accountId" if self.cloud else "key"
        user_value = user["accountId"] if self.cloud else user["userKey"]
        user_value = user_id
        # Server uses userKey (but calls it key during the API call), Cloud uses accountId
        user_query = f"{user_field}={quote(user_value)}"

@@ -423,11 +545,15 @@ def build_confluence_client(
    is_cloud: bool,
    wiki_base: str,
) -> OnyxConfluence:
    _validate_connector_configuration(
        credentials=credentials,
        is_cloud=is_cloud,
        wiki_base=wiki_base,
    )
    try:
        _validate_connector_configuration(
            credentials=credentials,
            is_cloud=is_cloud,
            wiki_base=wiki_base,
        )
    except Exception as e:
        raise ConnectorValidationError(str(e))

    return OnyxConfluence(
        api_version="cloud" if is_cloud else "latest",
        # Remove trailing slash from wiki_base if present

@@ -2,7 +2,10 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TYPE_CHECKING
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse

import bs4

@@ -10,13 +13,13 @@ from onyx.configs.app_configs import (
    CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.onyx_confluence import (
    OnyxConfluence,
)
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger

if TYPE_CHECKING:
    from onyx.connectors.confluence.onyx_confluence import OnyxConfluence

logger = setup_logger()


@@ -24,7 +27,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}


def get_user_email_from_username__server(
    confluence_client: OnyxConfluence, user_name: str
    confluence_client: "OnyxConfluence", user_name: str
) -> str | None:
    global _USER_EMAIL_CACHE
    if _USER_EMAIL_CACHE.get(user_name) is None:
@@ -47,7 +50,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}


def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
    """Get Confluence Display Name based on the account-id or userkey value

    Args:
@@ -78,7 +81,7 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:


def extract_text_from_confluence_html(
    confluence_client: OnyxConfluence,
    confluence_client: "OnyxConfluence",
    confluence_object: dict[str, Any],
    fetched_titles: set[str],
) -> str:
@@ -191,7 +194,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:


def attachment_to_content(
    confluence_client: OnyxConfluence,
    confluence_client: "OnyxConfluence",
    attachment: dict[str, Any],
) -> str | None:
    """If it returns None, assume that we should skip this attachment."""
@@ -279,3 +282,32 @@ def datetime_from_string(datetime_string: str) -> datetime:
        datetime_object = datetime_object.astimezone(timezone.utc)

    return datetime_object


def get_single_param_from_url(url: str, param: str) -> str | None:
    """Get a parameter from a url"""
    parsed_url = urlparse(url)
    return parse_qs(parsed_url.query).get(param, [None])[0]


def get_start_param_from_url(url: str) -> int:
    """Get the start parameter from a url"""
    start_str = get_single_param_from_url(url, "start")
    if start_str is None:
        return 0
    return int(start_str)


def update_param_in_path(path: str, param: str, value: str) -> str:
    """Update a parameter in a path. Path should look something like:

    /api/rest/users?start=0&limit=10
    """
    parsed_url = urlparse(path)
    query_params = parse_qs(parsed_url.query)
    query_params[param] = [value]
    return (
        path.split("?")[0]
        + "?"
        + "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
    )

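Behavior of the new helpers on the docstring's own example path (doctest-style; each assertion holds as written):

assert get_start_param_from_url("/api/rest/users?start=0&limit=10") == 0
assert get_single_param_from_url("/api/rest/users?start=0&limit=10", "limit") == "10"
assert (
    update_param_in_path("/api/rest/users?start=0&limit=10", "start", "10")
    == "/api/rest/users?start=10&limit=10"
)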
@@ -1,3 +1,4 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
@@ -24,16 +25,22 @@ def datetime_to_utc(dt: datetime) -> datetime:


def time_str_to_utc(datetime_str: str) -> datetime:
    # Remove all timezone abbreviations in parentheses
    datetime_str = re.sub(r"\([A-Z]+\)", "", datetime_str).strip()

    # Remove any remaining parentheses and their contents
    datetime_str = re.sub(r"\(.*?\)", "", datetime_str).strip()

    try:
        dt = parse(datetime_str)
    except ValueError:
        # Handle malformed timezone by attempting to fix common format issues
        # Fix common format issues (e.g. "0000" => "+0000")
        if "0000" in datetime_str:
            # Convert "0000" to "+0000" for proper timezone parsing
            fixed_dt_str = datetime_str.replace(" 0000", " +0000")
            dt = parse(fixed_dt_str)
            datetime_str = datetime_str.replace(" 0000", " +0000")
            dt = parse(datetime_str)
        else:
            raise

    return datetime_to_utc(dt)


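Together the two fixes handle inputs like an RFC-2822 timestamp carrying a parenthesized zone name and a bare numeric offset: the broadened regex strips "(UTC)", and if dateutil then rejects the bare "0000", the except branch repairs it to "+0000". An illustrative call (assuming dateutil raises ValueError on the bare offset, as this branch anticipates):

dt = time_str_to_utc("Tue, 18 Feb 2025 04:08:03 0000 (UTC)")
print(dt.isoformat())  # expected: 2025-02-18T04:08:03+00:00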
@@ -4,11 +4,15 @@ from typing import Any

from dropbox import Dropbox  # type: ignore
from dropbox.exceptions import ApiError  # type:ignore
from dropbox.exceptions import AuthError  # type:ignore
from dropbox.files import FileMetadata  # type:ignore
from dropbox.files import FolderMetadata  # type:ignore

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialInvalidError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -141,6 +145,29 @@ class DropboxConnector(LoadConnector, PollConnector):

        return None

    def validate_connector_settings(self) -> None:
        if self.dropbox_client is None:
            raise ConnectorMissingCredentialError("Dropbox credentials not loaded.")

        try:
            self.dropbox_client.files_list_folder(path="", limit=1)
        except AuthError as e:
            logger.exception("Failed to validate Dropbox credentials")
            raise CredentialInvalidError(f"Dropbox credential is invalid: {e.error}")
        except ApiError as e:
            if (
                e.error is not None
                and "insufficient_permissions" in str(e.error).lower()
            ):
                raise InsufficientPermissionsError(
                    "Your Dropbox token does not have sufficient permissions."
                )
            raise ConnectorValidationError(
                f"Unexpected Dropbox error during validation: {e.user_message_text or e}"
            )
        except Exception as e:
            raise Exception(f"Unexpected error during Dropbox settings validation: {e}")


if __name__ == "__main__":
    import os

backend/onyx/connectors/exceptions.py (new file, 49 lines)
@@ -0,0 +1,49 @@
|
||||
class ValidationError(Exception):
|
||||
"""General exception for validation errors."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class ConnectorValidationError(ValidationError):
|
||||
"""General exception for connector validation errors."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class UnexpectedError(ValidationError):
|
||||
"""Raised when an unexpected error occurs during connector validation.
|
||||
|
||||
Unexpected errors don't necessarily mean the credential is invalid,
|
||||
but rather that there was an error during the validation process
|
||||
or we encountered a currently unhandled error case.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "Unexpected error during connector validation"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CredentialInvalidError(ConnectorValidationError):
|
||||
"""Raised when a connector's credential is invalid."""
|
||||
|
||||
def __init__(self, message: str = "Credential is invalid"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CredentialExpiredError(ConnectorValidationError):
|
||||
"""Raised when a connector's credential is expired."""
|
||||
|
||||
def __init__(self, message: str = "Credential has expired"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InsufficientPermissionsError(ConnectorValidationError):
|
||||
"""Raised when the credential does not have sufficient API permissions."""
|
||||
|
||||
def __init__(
|
||||
self, message: str = "Insufficient permissions for the requested operation"
|
||||
):
|
||||
super().__init__(message)
|
||||
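Because the specific errors subclass ConnectorValidationError, callers can catch at whatever granularity they need; a small sketch using the import path defined by the new file above (runnable only with the onyx package on the path):

    from onyx.connectors.exceptions import (
        ConnectorValidationError,
        CredentialExpiredError,
        InsufficientPermissionsError,
    )

    def describe(exc: Exception) -> str:
        # Order matters: check subclasses before the base class.
        if isinstance(exc, CredentialExpiredError):
            return "re-authenticate and retry"
        if isinstance(exc, InsufficientPermissionsError):
            return "grant the missing scopes"
        if isinstance(exc, ConnectorValidationError):
            return "fix the connector configuration"
        return "unknown failure"

    print(describe(CredentialExpiredError()))  # -> re-authenticate and retry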
@@ -3,8 +3,8 @@ from typing import Type

from sqlalchemy.orm import Session

+from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
-from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
@@ -17,6 +17,7 @@ from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
from onyx.connectors.dropbox.connector import DropboxConnector
from onyx.connectors.egnyte.connector import EgnyteConnector
+from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.fireflies.connector import FirefliesConnector
from onyx.connectors.freshdesk.connector import FreshdeskConnector
@@ -52,7 +53,9 @@ from onyx.connectors.wikipedia.connector import WikipediaConnector
from onyx.connectors.xenforo.connector import XenforoConnector
from onyx.connectors.zendesk.connector import ZendeskConnector
from onyx.connectors.zulip.connector import ZulipConnector
+from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
+from onyx.db.credentials import fetch_credential_by_id
from onyx.db.models import Credential


@@ -160,13 +163,9 @@ def instantiate_connector(
    input_type: InputType,
    connector_specific_config: dict[str, Any],
    credential: Credential,
-   tenant_id: str | None = None,
) -> BaseConnector:
    connector_class = identify_connector_class(source, input_type)

-   if source in DocumentSourceRequiringTenantContext:
-       connector_specific_config["tenant_id"] = tenant_id
-
    connector = connector_class(**connector_specific_config)
    new_credentials = connector.load_credentials(credential.credential_json)

@@ -174,3 +173,51 @@
        backend_update_credential_json(credential, new_credentials, db_session)

    return connector


def validate_ccpair_for_user(
    connector_id: int,
    credential_id: int,
    db_session: Session,
    enforce_creation: bool = True,
) -> bool:
    if INTEGRATION_TESTS_MODE:
        return True

    # Validate the connector settings
    connector = fetch_connector_by_id(connector_id, db_session)
    credential = fetch_credential_by_id(
        credential_id,
        db_session,
    )

    if not connector:
        raise ValueError("Connector not found")

    if (
        connector.source == DocumentSource.INGESTION_API
        or connector.source == DocumentSource.MOCK_CONNECTOR
    ):
        return True

    if not credential:
        raise ValueError("Credential not found")

    try:
        runnable_connector = instantiate_connector(
            db_session=db_session,
            source=connector.source,
            input_type=connector.input_type,
            connector_specific_config=connector.connector_specific_config,
            credential=credential,
        )
    except ConnectorValidationError as e:
        raise e
    except Exception as e:
        if enforce_creation:
            raise ConnectorValidationError(str(e))
        else:
            return False

    runnable_connector.validate_connector_settings()
    return True
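A hedged sketch of how an API route might call this; the surrounding session wiring is illustrative only, and the factory import path is an assumption based on the functions shown above:

    from sqlalchemy.orm import Session

    from onyx.connectors.exceptions import ConnectorValidationError
    from onyx.connectors.factory import validate_ccpair_for_user

    def try_create_cc_pair(connector_id: int, credential_id: int, db_session: Session) -> None:
        try:
            # enforce_creation=True: failures to instantiate surface as
            # ConnectorValidationError rather than a silent False.
            validate_ccpair_for_user(
                connector_id, credential_id, db_session, enforce_creation=True
            )
        except ConnectorValidationError as e:
            # Bubble a user-facing message up to the API layer.
            raise RuntimeError(f"Connector/credential pair failed validation: {e}")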
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
-from onyx.db.engine import get_session_with_tenant
+from onyx.db.engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -27,8 +27,6 @@ from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
-from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
-from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

@@ -165,12 +163,10 @@ class LocalFileConnector(LoadConnector):
    def __init__(
        self,
        file_locations: list[Path | str],
-       tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.file_locations = [Path(file_location) for file_location in file_locations]
        self.batch_size = batch_size
-       self.tenant_id = tenant_id
        self.pdf_pass: str | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -179,9 +175,8 @@ class LocalFileConnector(LoadConnector):

    def load_from_state(self) -> GenerateDocumentsOutput:
        documents: list[Document] = []
-       token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)

-       with get_session_with_tenant(self.tenant_id) as db_session:
+       with get_session_with_current_tenant() as db_session:
            for file_path in self.file_locations:
                current_datetime = datetime.now(timezone.utc)
                files = _read_files_and_metadata(
@@ -203,8 +198,6 @@ class LocalFileConnector(LoadConnector):
        if documents:
            yield documents

-       CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


if __name__ == "__main__":
    connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
@@ -187,12 +187,12 @@ class FirefliesConnector(PollConnector, LoadConnector):
        return self._process_transcripts()

    def poll_source(
-       self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
+       self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
-       start_datetime = datetime.fromtimestamp(
-           start_unixtime, tz=timezone.utc
-       ).strftime("%Y-%m-%dT%H:%M:%S.000Z")
-       end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
+       start_datetime = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
+           "%Y-%m-%dT%H:%M:%S.000Z"
+       )
+       end_datetime = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
            "%Y-%m-%dT%H:%M:%S.000Z"
        )
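The renamed parameters still carry epoch seconds; for reference, a self-contained sketch of the timestamp formatting used here (that Fireflies expects this exact format is taken from the connector, not verified against the API docs):

    from datetime import datetime, timezone

    def to_api_timestamp(epoch_seconds: float) -> str:
        # ISO-8601 with zero-padded milliseconds and a literal Z suffix
        return datetime.fromtimestamp(epoch_seconds, tz=timezone.utc).strftime(
            "%Y-%m-%dT%H:%M:%S.000Z"
        )

    print(to_api_timestamp(0))           # 1970-01-01T00:00:00.000Z
    print(to_api_timestamp(1735689600))  # 2025-01-01T00:00:00.000Z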
@@ -229,16 +229,20 @@ class GitbookConnector(LoadConnector, PollConnector):

        try:
            content = self.client.get(f"/spaces/{self.space_id}/content")
-           pages = content.get("pages", [])
+           pages: list[dict[str, Any]] = content.get("pages", [])
            current_batch: list[Document] = []
-           for page in pages:
-               updated_at = datetime.fromisoformat(page["updatedAt"])
+           while pages:
+               page = pages.pop(0)
+
+               updated_at_raw = page.get("updatedAt")
+               if updated_at_raw is None:
+                   # if updatedAt is not present, that means the page has never been edited
+                   continue
+
+               updated_at = datetime.fromisoformat(updated_at_raw)
                if start and updated_at < start:
                    if current_batch:
                        yield current_batch
-                   return
+                   continue
                if end and updated_at > end:
                    continue

@@ -250,6 +254,8 @@ class GitbookConnector(LoadConnector, PollConnector):
                    yield current_batch
                    current_batch = []

+               pages.extend(page.get("pages", []))
+
            if current_batch:
                yield current_batch
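The switch from a flat for loop to `while pages: page = pages.pop(0)` plus `pages.extend(page.get("pages", []))` turns the traversal into a breadth-first walk over nested pages. A minimal standalone illustration of that queue pattern, on a toy tree mirroring GitBook's nested "pages" shape:

    tree = [
        {"title": "A", "pages": [{"title": "A.1"}, {"title": "A.2", "pages": [{"title": "A.2.1"}]}]},
        {"title": "B"},
    ]

    queue = list(tree)
    visited = []
    while queue:
        page = queue.pop(0)                  # FIFO -> breadth-first order
        visited.append(page["title"])
        queue.extend(page.get("pages", []))  # enqueue children

    print(visited)  # ['A', 'B', 'A.1', 'A.2', 'A.2.1']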
@@ -9,6 +9,7 @@ from typing import cast
from github import Github
from github import RateLimitExceededException
+from github import Repository
from github.GithubException import GithubException
from github.Issue import Issue
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
@@ -16,6 +17,10 @@ from github.PullRequest import PullRequest
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
+from onyx.connectors.exceptions import ConnectorValidationError
+from onyx.connectors.exceptions import CredentialExpiredError
+from onyx.connectors.exceptions import InsufficientPermissionsError
+from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -26,7 +31,6 @@ from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger


logger = setup_logger()
@@ -120,7 +124,7 @@ class GithubConnector(LoadConnector, PollConnector):
    def __init__(
        self,
        repo_owner: str,
-       repo_name: str,
+       repo_name: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
        state_filter: str = "all",
        include_prs: bool = True,
@@ -158,53 +162,81 @@ class GithubConnector(LoadConnector, PollConnector):
            _sleep_after_rate_limit_exception(github_client)
            return self._get_github_repo(github_client, attempt_num + 1)

    def _get_all_repos(
        self, github_client: Github, attempt_num: int = 0
    ) -> list[Repository.Repository]:
        if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
            raise RuntimeError(
                "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
            )

        try:
            # Try to get organization first
            try:
                org = github_client.get_organization(self.repo_owner)
                return list(org.get_repos())
            except GithubException:
                # If not an org, try as a user
                user = github_client.get_user(self.repo_owner)
                return list(user.get_repos())
        except RateLimitExceededException:
            _sleep_after_rate_limit_exception(github_client)
            return self._get_all_repos(github_client, attempt_num + 1)

    def _fetch_from_github(
        self, start: datetime | None = None, end: datetime | None = None
    ) -> GenerateDocumentsOutput:
        if self.github_client is None:
            raise ConnectorMissingCredentialError("GitHub")

-       repo = self._get_github_repo(self.github_client)
+       repos = (
+           [self._get_github_repo(self.github_client)]
+           if self.repo_name
+           else self._get_all_repos(self.github_client)
+       )

-       if self.include_prs:
-           pull_requests = repo.get_pulls(
-               state=self.state_filter, sort="updated", direction="desc"
-           )
+       for repo in repos:
+           if self.include_prs:
+               logger.info(f"Fetching PRs for repo: {repo.name}")
+               pull_requests = repo.get_pulls(
+                   state=self.state_filter, sort="updated", direction="desc"
+               )

-           for pr_batch in _batch_github_objects(
-               pull_requests, self.github_client, self.batch_size
-           ):
-               doc_batch: list[Document] = []
-               for pr in pr_batch:
-                   if start is not None and pr.updated_at < start:
-                       yield doc_batch
-                       return
-                   if end is not None and pr.updated_at > end:
-                       continue
-                   doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
-               yield doc_batch
+               for pr_batch in _batch_github_objects(
+                   pull_requests, self.github_client, self.batch_size
+               ):
+                   doc_batch: list[Document] = []
+                   for pr in pr_batch:
+                       if start is not None and pr.updated_at < start:
+                           yield doc_batch
+                           break
+                       if end is not None and pr.updated_at > end:
+                           continue
+                       doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
+                   yield doc_batch

-       if self.include_issues:
-           issues = repo.get_issues(
-               state=self.state_filter, sort="updated", direction="desc"
-           )
+           if self.include_issues:
+               logger.info(f"Fetching issues for repo: {repo.name}")
+               issues = repo.get_issues(
+                   state=self.state_filter, sort="updated", direction="desc"
+               )

-           for issue_batch in _batch_github_objects(
-               issues, self.github_client, self.batch_size
-           ):
-               doc_batch = []
-               for issue in issue_batch:
-                   issue = cast(Issue, issue)
-                   if start is not None and issue.updated_at < start:
-                       yield doc_batch
-                       return
-                   if end is not None and issue.updated_at > end:
-                       continue
-                   if issue.pull_request is not None:
-                       # PRs are handled separately
-                       continue
-                   doc_batch.append(_convert_issue_to_document(issue))
-               yield doc_batch
+               for issue_batch in _batch_github_objects(
+                   issues, self.github_client, self.batch_size
+               ):
+                   doc_batch = []
+                   for issue in issue_batch:
+                       issue = cast(Issue, issue)
+                       if start is not None and issue.updated_at < start:
+                           yield doc_batch
+                           break
+                       if end is not None and issue.updated_at > end:
+                           continue
+                       if issue.pull_request is not None:
+                           # PRs are handled separately
+                           continue
+                       doc_batch.append(_convert_issue_to_document(issue))
+                   yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        return self._fetch_from_github()
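One subtle correctness point in this hunk: with a single repo, hitting the start cutoff could end the whole generator with `return`, but once multiple repos are fetched it must `break` out of the current loop so later repos still get processed. A toy illustration of that scoping difference:

    def fetch(repos, cutoff):
        for repo in repos:
            for item in repo["items"]:  # items sorted newest-first
                if item < cutoff:
                    break   # stop this repo only; 'return' would end the generator
                yield (repo["name"], item)

    repos = [
        {"name": "a", "items": [9, 8, 2]},
        {"name": "b", "items": [7, 3]},
    ]
    print(list(fetch(repos, cutoff=5)))
    # [('a', 9), ('a', 8), ('b', 7)] -- with 'return', repo 'b' would be lost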
@@ -226,6 +258,63 @@ class GithubConnector(LoadConnector, PollConnector):

        return self._fetch_from_github(adjusted_start_datetime, end_datetime)

    def validate_connector_settings(self) -> None:
        if self.github_client is None:
            raise ConnectorMissingCredentialError("GitHub credentials not loaded.")

        if not self.repo_owner:
            raise ConnectorValidationError(
                "Invalid connector settings: 'repo_owner' must be provided."
            )

        try:
            if self.repo_name:
                test_repo = self.github_client.get_repo(
                    f"{self.repo_owner}/{self.repo_name}"
                )
                test_repo.get_contents("")
            else:
                # Try to get organization first
                try:
                    org = self.github_client.get_organization(self.repo_owner)
                    org.get_repos().totalCount  # Just check if we can access repos
                except GithubException:
                    # If not an org, try as a user
                    user = self.github_client.get_user(self.repo_owner)
                    user.get_repos().totalCount  # Just check if we can access repos

        except RateLimitExceededException:
            raise UnexpectedError(
                "Validation failed due to GitHub rate-limits being exceeded. Please try again later."
            )

        except GithubException as e:
            if e.status == 401:
                raise CredentialExpiredError(
                    "GitHub credential appears to be invalid or expired (HTTP 401)."
                )
            elif e.status == 403:
                raise InsufficientPermissionsError(
                    "Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
                )
            elif e.status == 404:
                if self.repo_name:
                    raise ConnectorValidationError(
                        f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
                    )
                else:
                    raise ConnectorValidationError(
                        f"GitHub user or organization not found: {self.repo_owner}"
                    )
            else:
                raise ConnectorValidationError(
                    f"Unexpected GitHub error (status={e.status}): {e.data}"
                )
        except Exception as exc:
            raise Exception(
                f"Unexpected error during GitHub settings validation: {exc}"
            )


if __name__ == "__main__":
    import os
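For reference, a standalone sketch of the probe this validation performs, using PyGithub directly; the token and repo values are placeholders, and the user-vs-organization fallback is collapsed for brevity:

    from github import Github
    from github.GithubException import GithubException

    def probe_github(token: str, repo_owner: str, repo_name: str | None = None) -> None:
        client = Github(token)
        try:
            if repo_name:
                # Listing root contents proves both that the repo exists
                # and that the token has read access to it.
                client.get_repo(f"{repo_owner}/{repo_name}").get_contents("")
            else:
                client.get_organization(repo_owner).get_repos().totalCount
        except GithubException as e:
            # 401 -> bad/expired token, 403 -> missing scopes, 404 -> wrong owner/name
            raise SystemExit(f"GitHub validation failed (status={e.status}): {e.data}")

    probe_github("ghp_YOUR_TOKEN", "octocat", "Hello-World")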
@@ -297,6 +297,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
            userId=user_email,
            fields=THREAD_LIST_FIELDS,
            q=query,
+           continue_on_404_or_403=True,
        ):
            full_threads = execute_paginated_retrieval(
                retrieval_function=gmail_service.users().threads().get,
@@ -304,6 +305,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
                userId=user_email,
                fields=THREAD_FIELDS,
                id=thread["id"],
+               continue_on_404_or_403=True,
            )
            # full_threads is an iterator containing a single thread
            # so we need to convert it to a list and grab the first element
@@ -335,6 +337,7 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
            userId=user_email,
            fields=THREAD_LIST_FIELDS,
            q=query,
+           continue_on_404_or_403=True,
        ):
            doc_batch.append(
                SlimDocument(
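The continue_on_404_or_403 flag makes per-user thread retrieval tolerant of mailboxes the credential cannot access. A generic sketch of that skip-on-permission-error pattern; the helper name and response shape below are illustrative, not the repo's actual execute_paginated_retrieval:

    from googleapiclient.errors import HttpError

    def paginate_skipping_forbidden(request_fn, **kwargs):
        """Yield Gmail threads page by page, skipping 403/404 instead of failing."""
        page_token = None
        while True:
            try:
                response = request_fn(pageToken=page_token, **kwargs).execute()
            except HttpError as e:
                if e.resp.status in (403, 404):
                    return  # caller simply moves on to the next user
                raise
            yield from response.get("threads", [])
            page_token = response.get("nextPageToken")
            if not page_token:
                return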
Some files were not shown because too many files have changed in this diff.