mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
14 Commits
fix_openap
...
readme-tip
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
487052db5d | ||
|
|
4e5ee3007d | ||
|
|
ae3608f1d6 | ||
|
|
0f1211401f | ||
|
|
5ac7ef1c21 | ||
|
|
667723af3f | ||
|
|
1580335c77 | ||
|
|
ee75fe3502 | ||
|
|
329c1591f0 | ||
|
|
1430de6bff | ||
|
|
73256ff827 | ||
|
|
f059b03ef3 | ||
|
|
6c10513165 | ||
|
|
5a610d1f48 |
5
.github/workflows/pr-integration-tests.yml
vendored
5
.github/workflows/pr-integration-tests.yml
vendored
@@ -145,7 +145,7 @@ jobs:
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
|
||||
|
||||
|
||||
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
@@ -157,7 +157,6 @@ jobs:
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
INTEGRATION_TESTS_MODE=true \
|
||||
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
|
||||
id: start_docker
|
||||
|
||||
@@ -200,7 +199,7 @@ jobs:
|
||||
cd backend/tests/integration/mock_services
|
||||
docker compose -f docker-compose.mock-it-services.yml \
|
||||
-p mock-it-services-stack up -d
|
||||
|
||||
|
||||
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
|
||||
- name: Run Standard Integration Tests
|
||||
run: |
|
||||
|
||||
@@ -74,9 +74,7 @@ jobs:
|
||||
python -m pip install --upgrade pip
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
playwright install chromium
|
||||
playwright install-deps chromium
|
||||
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
|
||||
|
||||
97
.github/workflows/pr-python-model-tests.yml
vendored
97
.github/workflows/pr-python-model-tests.yml
vendored
@@ -1,29 +1,18 @@
|
||||
name: Model Server Tests
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
schedule:
|
||||
# This cron expression runs the job daily at 16:00 UTC (9am PT)
|
||||
- cron: "0 16 * * *"
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
branch:
|
||||
description: 'Branch to run the workflow on'
|
||||
required: false
|
||||
default: 'main'
|
||||
|
||||
|
||||
env:
|
||||
# Bedrock
|
||||
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
|
||||
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
|
||||
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
|
||||
|
||||
# API keys for testing
|
||||
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
|
||||
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
|
||||
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
|
||||
# OpenAI
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
|
||||
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
|
||||
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
|
||||
|
||||
jobs:
|
||||
model-check:
|
||||
@@ -37,23 +26,6 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# tag every docker image with "test" so that we can spin up the correct set
|
||||
# of images during testing
|
||||
|
||||
# We don't need to build the Web Docker image since it's not yet used
|
||||
# in the integration tests. We have a separate action to verify that it builds
|
||||
# successfully.
|
||||
- name: Pull Model Server Docker image
|
||||
run: |
|
||||
docker pull onyxdotapp/onyx-model-server:latest
|
||||
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
|
||||
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
@@ -69,49 +41,6 @@ jobs:
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
|
||||
- name: Start Docker containers
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
AUTH_TYPE=basic \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
DISABLE_TELEMETRY=true \
|
||||
IMAGE_TAG=test \
|
||||
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
|
||||
id: start_docker
|
||||
|
||||
- name: Wait for service to be ready
|
||||
run: |
|
||||
echo "Starting wait-for-service script..."
|
||||
|
||||
start_time=$(date +%s)
|
||||
timeout=300 # 5 minutes in seconds
|
||||
|
||||
while true; do
|
||||
current_time=$(date +%s)
|
||||
elapsed_time=$((current_time - start_time))
|
||||
|
||||
if [ $elapsed_time -ge $timeout ]; then
|
||||
echo "Timeout reached. Service did not become ready in 5 minutes."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Use curl with error handling to ignore specific exit code 56
|
||||
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
|
||||
|
||||
if [ "$response" = "200" ]; then
|
||||
echo "Service is ready!"
|
||||
break
|
||||
elif [ "$response" = "curl_error" ]; then
|
||||
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
|
||||
else
|
||||
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
|
||||
fi
|
||||
|
||||
sleep 5
|
||||
done
|
||||
echo "Finished waiting for service."
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: |
|
||||
@@ -127,23 +56,3 @@ jobs:
|
||||
-H 'Content-type: application/json' \
|
||||
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
|
||||
$SLACK_WEBHOOK
|
||||
|
||||
- name: Dump all-container logs (optional)
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
|
||||
|
||||
- name: Upload logs
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: docker-all-logs
|
||||
path: ${{ github.workspace }}/docker-compose.log
|
||||
|
||||
- name: Stop Docker containers
|
||||
if: always()
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v
|
||||
|
||||
|
||||
26
README.md
26
README.md
@@ -26,12 +26,12 @@
|
||||
|
||||
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
|
||||
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
|
||||
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
|
||||
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
|
||||
There are over 40 supported connectors such as Google Drive, Slack, Confluence, Salesforce, etc. which keep knowledge and permissions up to date.
|
||||
Create custom AI agents with unique prompts, knowledge, and actions the agents can take.
|
||||
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
|
||||
|
||||
|
||||
<h3>Feature Highlights</h3>
|
||||
<h3>Feature Showcase</h3>
|
||||
|
||||
**Deep research over your team's knowledge:**
|
||||
|
||||
@@ -54,7 +54,8 @@ https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95
|
||||
|
||||
|
||||
## Deployment
|
||||
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
|
||||
> [!TIP]
|
||||
> To try it out for free and get started in seconds, check out **[Onyx Cloud](https://cloud.onyx.app/signup)**.
|
||||
|
||||
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
|
||||
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
|
||||
@@ -63,21 +64,22 @@ We also have built-in support for high-availability/scalable deployment on Kuber
|
||||
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
|
||||
|
||||
|
||||
## 🔍 Other Notable Benefits of Onyx
|
||||
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
|
||||
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
|
||||
- Knowledge curation features like document-sets, query history, usage analytics, etc.
|
||||
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
|
||||
|
||||
|
||||
## 🚧 Roadmap
|
||||
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
|
||||
- Extensions to the Chrome Plugin
|
||||
- Latest methods in information retrieval (StructRAG, LightGraphRAG, etc.)
|
||||
- Personalized Search
|
||||
- Organizational understanding and ability to locate and suggest experts from your team.
|
||||
- Code Search
|
||||
- SQL and Structured Query Language
|
||||
|
||||
|
||||
## 🔍 Other Notable Benefits of Onyx
|
||||
- Custom deep learning models only through Onyx + learn from user feedback.
|
||||
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
|
||||
- Knowledge curation features like document-sets, query history, usage analytics, etc.
|
||||
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
|
||||
|
||||
|
||||
## 🔌 Connectors
|
||||
Keep knowledge and access up to sync across 40+ connectors:
|
||||
|
||||
|
||||
@@ -28,11 +28,11 @@ RUN apt-get update && \
|
||||
curl \
|
||||
zip \
|
||||
ca-certificates \
|
||||
libgnutls30 \
|
||||
libblkid1 \
|
||||
libmount1 \
|
||||
libsmartcols1 \
|
||||
libuuid1 \
|
||||
libgnutls30=3.7.9-2+deb12u3 \
|
||||
libblkid1=2.38.1-5+deb12u1 \
|
||||
libmount1=2.38.1-5+deb12u1 \
|
||||
libsmartcols1=2.38.1-5+deb12u1 \
|
||||
libuuid1=2.38.1-5+deb12u1 \
|
||||
libxmlsec1-dev \
|
||||
pkg-config \
|
||||
gcc \
|
||||
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Add indexes to document__tag
|
||||
|
||||
Revision ID: 1a03d2c2856b
|
||||
Revises: 9c00a2bccb83
|
||||
Create Date: 2025-02-18 10:45:13.957807
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1a03d2c2856b"
|
||||
down_revision = "9c00a2bccb83"
|
||||
branch_labels: None = None
|
||||
depends_on: None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
op.f("ix_document__tag_tag_id"),
|
||||
"document__tag",
|
||||
["tag_id"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f("ix_document__tag_tag_id"), table_name="document__tag")
|
||||
@@ -1,31 +0,0 @@
|
||||
"""add index
|
||||
|
||||
Revision ID: 8f43500ee275
|
||||
Revises: da42808081e3
|
||||
Create Date: 2025-02-24 17:35:33.072714
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "8f43500ee275"
|
||||
down_revision = "da42808081e3"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create a basic index on the lowercase message column for direct text matching
|
||||
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX idx_chat_message_message_lower
|
||||
ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the index
|
||||
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
@@ -1,43 +0,0 @@
|
||||
"""chat_message_agentic
|
||||
|
||||
Revision ID: 9c00a2bccb83
|
||||
Revises: b7a7eee5aa15
|
||||
Create Date: 2025-02-17 11:15:43.081150
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9c00a2bccb83"
|
||||
down_revision = "b7a7eee5aa15"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# First add the column as nullable
|
||||
op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))
|
||||
|
||||
# Update existing rows based on presence of SubQuestions
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET is_agentic = EXISTS (
|
||||
SELECT 1
|
||||
FROM agent__sub_question
|
||||
WHERE agent__sub_question.primary_question_id = chat_message.id
|
||||
)
|
||||
WHERE is_agentic IS NULL
|
||||
"""
|
||||
)
|
||||
|
||||
# Make the column non-nullable with a default value of False
|
||||
op.alter_column(
|
||||
"chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("chat_message", "is_agentic")
|
||||
@@ -1,29 +0,0 @@
|
||||
"""remove inactive ccpair status on downgrade
|
||||
|
||||
Revision ID: acaab4ef4507
|
||||
Revises: b388730a2899
|
||||
Create Date: 2025-02-16 18:21:41.330212
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from sqlalchemy import update
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "acaab4ef4507"
|
||||
down_revision = "b388730a2899"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
update(ConnectorCredentialPair)
|
||||
.where(ConnectorCredentialPair.status == ConnectorCredentialPairStatus.INVALID)
|
||||
.values(status=ConnectorCredentialPairStatus.ACTIVE)
|
||||
)
|
||||
@@ -1,31 +0,0 @@
|
||||
"""nullable preferences
|
||||
|
||||
Revision ID: b388730a2899
|
||||
Revises: 1a03d2c2856b
|
||||
Create Date: 2025-02-17 18:49:22.643902
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b388730a2899"
|
||||
down_revision = "1a03d2c2856b"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("user", "temperature_override_enabled", nullable=True)
|
||||
op.alter_column("user", "auto_scroll", nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Ensure no null values before making columns non-nullable
|
||||
op.execute(
|
||||
'UPDATE "user" SET temperature_override_enabled = false WHERE temperature_override_enabled IS NULL'
|
||||
)
|
||||
op.execute('UPDATE "user" SET auto_scroll = false WHERE auto_scroll IS NULL')
|
||||
|
||||
op.alter_column("user", "temperature_override_enabled", nullable=False)
|
||||
op.alter_column("user", "auto_scroll", nullable=False)
|
||||
@@ -1,120 +0,0 @@
|
||||
"""migrate jira connectors to new format
|
||||
|
||||
Revision ID: da42808081e3
|
||||
Revises: f13db29f3101
|
||||
Create Date: 2025-02-24 11:24:54.396040
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.onyx_jira.utils import extract_jira_project
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "da42808081e3"
|
||||
down_revision = "f13db29f3101"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get all Jira connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
# First get all Jira connectors
|
||||
jira_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = :source
|
||||
"""
|
||||
),
|
||||
{"source": DocumentSource.JIRA.value.upper()},
|
||||
).fetchall()
|
||||
|
||||
# Update each connector's config
|
||||
for connector_id, old_config in jira_connectors:
|
||||
if not old_config:
|
||||
continue
|
||||
|
||||
# Extract project key from URL if it exists
|
||||
new_config: dict[str, str | None] = {}
|
||||
if project_url := old_config.get("jira_project_url"):
|
||||
# Parse the URL to get base and project
|
||||
try:
|
||||
jira_base, project_key = extract_jira_project(project_url)
|
||||
new_config = {"jira_base_url": jira_base, "project_key": project_key}
|
||||
except ValueError:
|
||||
# If URL parsing fails, just use the URL as the base
|
||||
new_config = {
|
||||
"jira_base_url": project_url.split("/projects/")[0],
|
||||
"project_key": None,
|
||||
}
|
||||
else:
|
||||
# For connectors without a project URL, we need admin intervention
|
||||
# Mark these for review
|
||||
print(
|
||||
f"WARNING: Jira connector {connector_id} has no project URL configured"
|
||||
)
|
||||
continue
|
||||
|
||||
# Update the connector config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :id
|
||||
"""
|
||||
),
|
||||
{"id": connector_id, "new_config": json.dumps(new_config)},
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Get all Jira connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
# First get all Jira connectors
|
||||
jira_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = :source
|
||||
"""
|
||||
),
|
||||
{"source": DocumentSource.JIRA.value.upper()},
|
||||
).fetchall()
|
||||
|
||||
# Update each connector's config back to the old format
|
||||
for connector_id, new_config in jira_connectors:
|
||||
if not new_config:
|
||||
continue
|
||||
|
||||
old_config = {}
|
||||
base_url = new_config.get("jira_base_url")
|
||||
project_key = new_config.get("project_key")
|
||||
|
||||
if base_url and project_key:
|
||||
old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
|
||||
elif base_url:
|
||||
old_config = {"jira_project_url": base_url}
|
||||
else:
|
||||
continue
|
||||
|
||||
# Update the connector config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :old_config
|
||||
WHERE id = :id
|
||||
"""
|
||||
),
|
||||
{"id": connector_id, "old_config": old_config},
|
||||
)
|
||||
@@ -1,27 +0,0 @@
|
||||
"""Add composite index for last_modified and last_synced to document
|
||||
|
||||
Revision ID: f13db29f3101
|
||||
Revises: b388730a2899
|
||||
Create Date: 2025-02-18 22:48:11.511389
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f13db29f3101"
|
||||
down_revision = "acaab4ef4507"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | None = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_index(
|
||||
"ix_document_sync_status",
|
||||
"document",
|
||||
["last_modified", "last_synced"],
|
||||
unique=False,
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_document_sync_status", table_name="document")
|
||||
@@ -21,7 +21,7 @@ logger = setup_logger()
|
||||
def perform_ttl_management_task(
|
||||
retention_limit_days: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
delete_chat_sessions_older_than(retention_limit_days, db_session)
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(
|
||||
@@ -62,7 +62,7 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
"""This generates usage report under the /admin/generate-usage/report endpoint"""
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=None,
|
||||
|
||||
@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
@@ -36,11 +35,10 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
|
||||
def get_cc_pairs_by_source(
|
||||
db_session: Session,
|
||||
source_type: DocumentSource,
|
||||
access_type: AccessType | None = None,
|
||||
status: ConnectorCredentialPairStatus | None = None,
|
||||
only_sync: bool,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
"""
|
||||
Get all cc_pairs for a given source type with optional filtering by access_type and status
|
||||
Get all cc_pairs for a given source type (and optionally only sync)
|
||||
result is sorted by cc_pair id
|
||||
"""
|
||||
query = (
|
||||
@@ -50,11 +48,8 @@ def get_cc_pairs_by_source(
|
||||
.order_by(ConnectorCredentialPair.id)
|
||||
)
|
||||
|
||||
if access_type is not None:
|
||||
query = query.filter(ConnectorCredentialPair.access_type == access_type)
|
||||
|
||||
if status is not None:
|
||||
query = query.filter(ConnectorCredentialPair.status == status)
|
||||
if only_sync:
|
||||
query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)
|
||||
|
||||
cc_pairs = query.all()
|
||||
return cc_pairs
|
||||
|
||||
@@ -14,24 +14,30 @@ def _build_group_member_email_map(
|
||||
confluence_client: OnyxConfluence, cc_pair_id: int
|
||||
) -> dict[str, set[str]]:
|
||||
group_member_emails: dict[str, set[str]] = {}
|
||||
for user in confluence_client.paginated_cql_user_retrieval():
|
||||
logger.debug(f"Processing groups for user: {user}")
|
||||
for user_result in confluence_client.paginated_cql_user_retrieval():
|
||||
logger.debug(f"Processing groups for user: {user_result}")
|
||||
|
||||
email = user.email
|
||||
user = user_result.get("user", {})
|
||||
if not user:
|
||||
msg = f"user result missing user field: {user_result}"
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
logger.error(msg)
|
||||
continue
|
||||
|
||||
email = user.get("email")
|
||||
if not email:
|
||||
# This field is only present in Confluence Server
|
||||
user_name = user.username
|
||||
user_name = user.get("username")
|
||||
# If it is present, try to get the email using a Server-specific method
|
||||
if user_name:
|
||||
email = get_user_email_from_username__server(
|
||||
confluence_client=confluence_client,
|
||||
user_name=user_name,
|
||||
)
|
||||
|
||||
if not email:
|
||||
# If we still don't have an email, skip this user
|
||||
msg = f"user result missing email field: {user}"
|
||||
if user.type == "app":
|
||||
msg = f"user result missing email field: {user_result}"
|
||||
if user.get("type") == "app":
|
||||
logger.warning(msg)
|
||||
else:
|
||||
emit_background_error(msg, cc_pair_id=cc_pair_id)
|
||||
@@ -39,7 +45,7 @@ def _build_group_member_email_map(
|
||||
continue
|
||||
|
||||
all_users_groups: set[str] = set()
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user.user_id):
|
||||
for group in confluence_client.paginated_groups_by_user_retrieval(user):
|
||||
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
|
||||
group_id = group["name"]
|
||||
group_member_emails.setdefault(group_id, set()).add(email)
|
||||
|
||||
@@ -62,14 +62,12 @@ def _fetch_permissions_for_permission_ids(
|
||||
user_email=(owner_email or google_drive_connector.primary_admin_email),
|
||||
)
|
||||
|
||||
# We continue on 404 or 403 because the document may not exist or the user may not have access to it
|
||||
fetched_permissions = execute_paginated_retrieval(
|
||||
retrieval_function=drive_service.permissions().list,
|
||||
list_key="permissions",
|
||||
fileId=doc_id,
|
||||
fields="permissions(id, emailAddress, type, domain)",
|
||||
supportsAllDrives=True,
|
||||
continue_on_404_or_403=True,
|
||||
)
|
||||
|
||||
permissions_for_doc_id = []
|
||||
@@ -106,13 +104,7 @@ def _get_permissions_from_slim_doc(
|
||||
user_emails: set[str] = set()
|
||||
group_emails: set[str] = set()
|
||||
public = False
|
||||
skipped_permissions = 0
|
||||
|
||||
for permission in permissions_list:
|
||||
if not permission:
|
||||
skipped_permissions += 1
|
||||
continue
|
||||
|
||||
permission_type = permission["type"]
|
||||
if permission_type == "user":
|
||||
user_emails.add(permission["emailAddress"])
|
||||
@@ -129,11 +121,6 @@ def _get_permissions_from_slim_doc(
|
||||
elif permission_type == "anyone":
|
||||
public = True
|
||||
|
||||
if skipped_permissions > 0:
|
||||
logger.warning(
|
||||
f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
|
||||
)
|
||||
|
||||
drive_id = permission_info.get("drive_id")
|
||||
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
|
||||
return await call_next(request)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in tenant ID middleware: {str(e)}")
|
||||
logger.error(f"Error in tenant ID middleware: {str(e)}")
|
||||
raise
|
||||
|
||||
|
||||
@@ -49,7 +49,7 @@ async def _get_tenant_id_from_request(
|
||||
"""
|
||||
# Check for API key
|
||||
tenant_id = extract_tenant_from_api_key_header(request)
|
||||
if tenant_id is not None:
|
||||
if tenant_id:
|
||||
return tenant_id
|
||||
|
||||
# Check for anonymous user cookie
|
||||
|
||||
@@ -36,12 +36,12 @@ from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -271,12 +271,12 @@ def prepare_authorization_request(
|
||||
connector: DocumentSource,
|
||||
redirect_on_success: str | None,
|
||||
user: User = Depends(current_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
"""Used by the frontend to generate the url for the user's browser during auth request.
|
||||
|
||||
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
# create random oauth state param for security and to retrieve user data later
|
||||
oauth_uuid = uuid.uuid4()
|
||||
@@ -329,6 +329,7 @@ def handle_slack_oauth_callback(
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
@@ -336,7 +337,7 @@ def handle_slack_oauth_callback(
|
||||
detail="Slack client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
@@ -522,6 +523,7 @@ def handle_google_drive_oauth_callback(
|
||||
state: str,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> JSONResponse:
|
||||
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
@@ -529,7 +531,7 @@ def handle_google_drive_oauth_callback(
|
||||
detail="Google Drive client ID or client secret is not configured.",
|
||||
)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# recover the state
|
||||
padded_state = state + "=" * (
|
||||
|
||||
@@ -13,7 +13,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.api_key import is_api_key_email_address
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
from onyx.db.models import TokenRateLimit
|
||||
@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
|
||||
def _check_token_rate_limits(user: User | None) -> None:
|
||||
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
|
||||
if user is None:
|
||||
# Unauthenticated users are only rate limited by global settings
|
||||
_user_is_rate_limited_by_global()
|
||||
_user_is_rate_limited_by_global(tenant_id)
|
||||
|
||||
elif is_api_key_email_address(user.email):
|
||||
# API keys are only rate limited by global settings
|
||||
_user_is_rate_limited_by_global()
|
||||
_user_is_rate_limited_by_global(tenant_id)
|
||||
|
||||
else:
|
||||
run_functions_tuples_in_parallel(
|
||||
[
|
||||
(_user_is_rate_limited, (user.id,)),
|
||||
(_user_is_rate_limited_by_group, (user.id,)),
|
||||
(_user_is_rate_limited_by_global, ()),
|
||||
(_user_is_rate_limited, (user.id, tenant_id)),
|
||||
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
|
||||
(_user_is_rate_limited_by_global, (tenant_id,)),
|
||||
]
|
||||
)
|
||||
|
||||
@@ -52,8 +52,8 @@ User rate limits
|
||||
"""
|
||||
|
||||
|
||||
def _user_is_rate_limited(user_id: UUID) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
user_rate_limits = fetch_all_user_token_rate_limits(
|
||||
db_session=db_session, enabled_only=True, ordered=False
|
||||
)
|
||||
@@ -93,8 +93,8 @@ User Group rate limits
|
||||
"""
|
||||
|
||||
|
||||
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
|
||||
|
||||
if group_rate_limits:
|
||||
|
||||
@@ -41,15 +41,14 @@ from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
@@ -58,14 +57,13 @@ router = APIRouter(prefix="/tenants")
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
@@ -74,15 +72,15 @@ async def get_anonymous_user_path_api(
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
@@ -103,7 +101,7 @@ async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
@@ -152,17 +150,14 @@ async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
tenant_id = get_current_tenant_id()
|
||||
return fetch_billing_information(tenant_id)
|
||||
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
|
||||
try:
|
||||
# Fetch tenant_id and current tenant's information
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
@@ -186,8 +181,6 @@ async def create_subscription_session(
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
@@ -204,7 +197,7 @@ async def impersonate_user(
|
||||
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
|
||||
tenant_id = get_tenant_id_for_email(impersonate_request.email)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
with get_session_with_tenant(tenant_id) as tenant_session:
|
||||
user_to_impersonate = get_user_by_email(
|
||||
impersonate_request.email, tenant_session
|
||||
)
|
||||
@@ -228,9 +221,8 @@ async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str = Depends(get_current_tenant_id),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
|
||||
@@ -118,7 +118,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
# Await the Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
@@ -134,7 +134,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
@@ -224,7 +224,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
name="Anthropic",
|
||||
provider=ANTHROPIC_PROVIDER_NAME,
|
||||
api_key=ANTHROPIC_DEFAULT_API_KEY,
|
||||
default_model_name="claude-3-7-sonnet-20250219",
|
||||
default_model_name="claude-3-5-sonnet-20241022",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
|
||||
|
||||
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
|
||||
@@ -98,17 +98,12 @@ class CloudEmbedding:
|
||||
return final_embeddings
|
||||
except Exception as e:
|
||||
error_string = (
|
||||
f"Exception embedding text with OpenAI - {type(e)}: "
|
||||
f"Model: {model} "
|
||||
f"Provider: {self.provider} "
|
||||
f"Exception: {e}"
|
||||
f"Error embedding text with OpenAI: {str(e)} \n"
|
||||
f"Model: {model} \n"
|
||||
f"Provider: {self.provider} \n"
|
||||
f"Texts: {texts}"
|
||||
)
|
||||
logger.error(error_string)
|
||||
|
||||
# only log text when it's not an authentication error.
|
||||
if not isinstance(e, openai.AuthenticationError):
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
|
||||
async def _embed_cohere(
|
||||
|
||||
@@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
|
||||
from onyx.agents.agent_search.basic.states import BasicInput
|
||||
from onyx.agents.agent_search.basic.states import BasicOutput
|
||||
from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="choose_tool",
|
||||
action=choose_tool,
|
||||
node="llm_tool_choice",
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
node="call_tool",
|
||||
action=call_tool,
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
graph.add_node(
|
||||
@@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:
|
||||
|
||||
graph.add_edge(start_key=START, end_key="prepare_tool_input")
|
||||
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
|
||||
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
|
||||
|
||||
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
|
||||
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
|
||||
|
||||
graph.add_edge(
|
||||
start_key="call_tool",
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
|
||||
@@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
|
||||
# If there are no tool calls, basic graph already streamed the answer
|
||||
END
|
||||
if state.tool_choice is None
|
||||
else "call_tool"
|
||||
else "tool_call"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -31,14 +31,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -87,11 +85,9 @@ def check_sub_answer(
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: BaseMessage | None = None
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
|
||||
fast_llm.invoke,
|
||||
response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
|
||||
)
|
||||
|
||||
quality_str: str = cast(str, response.content)
|
||||
@@ -100,7 +96,7 @@ def check_sub_answer(
|
||||
)
|
||||
log_result = f"Answer quality: {quality_str}"
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import merge_message_runs
|
||||
@@ -46,13 +47,11 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -111,14 +110,15 @@ def generate_sub_answer(
|
||||
config=fast_llm.config,
|
||||
)
|
||||
|
||||
response: list[str | list[str | dict[str, Any]]] = []
|
||||
dispatch_timings: list[float] = []
|
||||
agent_error: AgentErrorLog | None = None
|
||||
response: list[str] = []
|
||||
|
||||
def stream_sub_answer() -> list[str]:
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
try:
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -142,15 +142,8 @@ def generate_sub_answer(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
return response
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
|
||||
stream_sub_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -59,15 +60,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -80,7 +77,6 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
_llm_node_error_strings = LLMNodeErrorStrings(
|
||||
@@ -234,11 +230,7 @@ def generate_initial_answer(
|
||||
|
||||
sub_questions = all_sub_questions # Replace the original assignment
|
||||
|
||||
model = (
|
||||
graph_config.tooling.fast_llm
|
||||
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
else graph_config.tooling.primary_llm
|
||||
)
|
||||
model = graph_config.tooling.fast_llm
|
||||
|
||||
doc_context = format_docs(answer_generation_documents.context_documents)
|
||||
doc_context = trim_prompt_piece(
|
||||
@@ -268,16 +260,15 @@ def generate_initial_answer(
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str] = [""]
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
def stream_initial_answer() -> list[str]:
|
||||
response: list[str] = []
|
||||
try:
|
||||
for message in model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -301,16 +292,9 @@ def generate_initial_answer(
|
||||
dispatch_timings.append(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
response.append(content)
|
||||
return response
|
||||
streamed_tokens.append(content)
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
stream_initial_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -36,10 +36,7 @@ from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -50,7 +47,6 @@ from onyx.prompts.agent_search import (
|
||||
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -135,12 +131,10 @@ def decompose_orig_question(
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
|
||||
dispatch_separated,
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(0, writer),
|
||||
sep_callback=dispatch_subquestion_sep(0, writer),
|
||||
@@ -160,7 +154,7 @@ def decompose_orig_question(
|
||||
)
|
||||
write_custom_event("stream_finished", stop_event, writer)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError) as e:
|
||||
except LLMTimeoutError as e:
|
||||
logger.error("LLM Timeout Error - decompose orig question")
|
||||
raise e # fail loudly on this critical step
|
||||
except LLMRateLimitError as e:
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = setup_logger()
|
||||
|
||||
def route_initial_tool_choice(
|
||||
state: MainState, config: RunnableConfig
|
||||
) -> Literal["call_tool", "start_agent_search", "logging_node"]:
|
||||
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
|
||||
"""
|
||||
LangGraph edge to route to agent search.
|
||||
"""
|
||||
@@ -38,7 +38,7 @@ def route_initial_tool_choice(
|
||||
):
|
||||
return "start_agent_search"
|
||||
else:
|
||||
return "call_tool"
|
||||
return "tool_call"
|
||||
else:
|
||||
return "logging_node"
|
||||
|
||||
|
||||
@@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
|
||||
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
|
||||
answer_refined_query_graph_builder,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
|
||||
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
|
||||
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
|
||||
prepare_tool_input,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
|
||||
basic_use_tool_response,
|
||||
)
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
# Choose the initial tool
|
||||
graph.add_node(
|
||||
node="initial_tool_choice",
|
||||
action=choose_tool,
|
||||
action=llm_tool_choice,
|
||||
)
|
||||
|
||||
# Call the tool, if required
|
||||
graph.add_node(
|
||||
node="call_tool",
|
||||
action=call_tool,
|
||||
node="tool_call",
|
||||
action=tool_call,
|
||||
)
|
||||
|
||||
# Use the tool response
|
||||
@@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
|
||||
graph.add_conditional_edges(
|
||||
"initial_tool_choice",
|
||||
route_initial_tool_choice,
|
||||
["call_tool", "start_agent_search", "logging_node"],
|
||||
["tool_call", "start_agent_search", "logging_node"],
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key="call_tool",
|
||||
start_key="tool_call",
|
||||
end_key="basic_use_tool_response",
|
||||
)
|
||||
graph.add_edge(
|
||||
|
||||
@@ -33,15 +33,13 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -107,14 +105,11 @@ def compare_answers(
|
||||
refined_answer_improvement: bool | None = None
|
||||
# no need to stream this
|
||||
try:
|
||||
resp = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
|
||||
model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
|
||||
resp = model.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -44,10 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -56,7 +53,6 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -138,17 +134,15 @@ def create_refined_sub_questions(
|
||||
agent_error: AgentErrorLog | None = None
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
dispatch_separated,
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(1, writer),
|
||||
sep_callback=dispatch_subquestion_sep(1, writer),
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -22,17 +22,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
|
||||
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
|
||||
@@ -90,42 +84,30 @@ def extract_entities_terms(
|
||||
]
|
||||
fast_llm = graph_config.tooling.fast_llm
|
||||
# Grader
|
||||
llm_response = fast_llm.invoke(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
llm_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
entity_extraction_result = EntityExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
|
||||
)
|
||||
first_bracket = cleaned_response.find("{")
|
||||
last_bracket = cleaned_response.rfind("}")
|
||||
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
|
||||
|
||||
try:
|
||||
entity_extraction_result = EntityExtractionResult.model_validate_json(
|
||||
cleaned_response
|
||||
)
|
||||
except ValueError:
|
||||
logger.error(
|
||||
"Failed to parse LLM response as JSON in Entity-Term Extraction"
|
||||
)
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
logger.error("LLM Timeout Error - extract entities terms")
|
||||
except ValueError:
|
||||
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
)
|
||||
|
||||
except LLMRateLimitError:
|
||||
logger.error("LLM Rate Limit Error - extract entities terms")
|
||||
entity_extraction_result = EntityExtractionResult(
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
|
||||
retrieved_entities_relationships=EntityRelationshipTermExtraction(
|
||||
entities=[],
|
||||
relationships=[],
|
||||
terms=[],
|
||||
),
|
||||
)
|
||||
|
||||
return EntityTermExtractionUpdate(
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
@@ -65,21 +66,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -98,7 +92,6 @@ from onyx.prompts.agent_search import (
|
||||
from onyx.prompts.agent_search import UNKNOWN_ANSWER
|
||||
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -260,12 +253,7 @@ def generate_validate_refined_answer(
|
||||
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
|
||||
)
|
||||
|
||||
model = (
|
||||
graph_config.tooling.fast_llm
|
||||
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
else graph_config.tooling.primary_llm
|
||||
)
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
|
||||
relevant_docs_str = trim_prompt_piece(
|
||||
model.config,
|
||||
@@ -296,13 +284,13 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
]
|
||||
|
||||
streamed_tokens: list[str] = [""]
|
||||
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
|
||||
dispatch_timings: list[float] = []
|
||||
agent_error: AgentErrorLog | None = None
|
||||
|
||||
def stream_refined_answer() -> list[str]:
|
||||
try:
|
||||
for message in model.stream(
|
||||
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -327,15 +315,8 @@ def generate_validate_refined_answer(
|
||||
(end_stream_token - start_stream_token).microseconds
|
||||
)
|
||||
streamed_tokens.append(content)
|
||||
return streamed_tokens
|
||||
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
|
||||
stream_refined_answer,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
@@ -402,20 +383,16 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
]
|
||||
|
||||
validation_model = graph_config.tooling.fast_llm
|
||||
try:
|
||||
validation_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
validation_model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
validation_response = model.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
refined_answer_quality = binary_string_test_after_answer_separator(
|
||||
text=cast(str, validation_response.content),
|
||||
positive_value=AGENT_POSITIVE_VALUE_STR,
|
||||
separator=AGENT_ANSWER_SEPARATOR,
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
refined_answer_quality = True
|
||||
logger.error("LLM Timeout Error - validate refined answer")
|
||||
|
||||
|
||||
@@ -34,16 +34,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
QUERY_REWRITING_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -71,7 +69,7 @@ def expand_queries(
|
||||
node_start_time = datetime.now()
|
||||
question = state.question
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
llm = graph_config.tooling.fast_llm
|
||||
sub_question_id = state.sub_question_id
|
||||
if sub_question_id is None:
|
||||
level, question_num = 0, 0
|
||||
@@ -90,12 +88,10 @@ def expand_queries(
|
||||
rewritten_queries = []
|
||||
|
||||
try:
|
||||
llm_response_list = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
|
||||
dispatch_separated,
|
||||
model.stream(
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
|
||||
),
|
||||
dispatch_subquery(level, question_num, writer),
|
||||
)
|
||||
@@ -105,7 +101,7 @@ def expand_queries(
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
@@ -55,7 +55,6 @@ def rerank_documents(
|
||||
|
||||
# Note that these are passed in values from the API and are overrides which are typically None
|
||||
rerank_settings = graph_config.inputs.search_request.rerank_settings
|
||||
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
|
||||
|
||||
if rerank_settings is None:
|
||||
with get_session_context_manager() as db_session:
|
||||
@@ -63,31 +62,23 @@ def rerank_documents(
|
||||
if not search_settings.disable_rerank_for_streaming:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
|
||||
# Initial default: no reranking. Will be overwritten below if reranking is warranted
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if should_rerank(rerank_settings) and len(verified_documents) > 0:
|
||||
if len(verified_documents) > 1:
|
||||
if not allow_agent_reranking:
|
||||
logger.info("Use of local rerank model without GPU, skipping reranking")
|
||||
# No reranking, stay with verified_documents as default
|
||||
|
||||
else:
|
||||
# Reranking is warranted, use the rerank_sections functon
|
||||
reranked_documents = rerank_sections(
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
reranked_documents = rerank_sections(
|
||||
query_str=question,
|
||||
# if runnable, then rerank_settings is not None
|
||||
rerank_settings=cast(RerankingDetails, rerank_settings),
|
||||
sections_to_rerank=verified_documents,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"{len(verified_documents)} verified document(s) found, skipping reranking"
|
||||
)
|
||||
# No reranking, stay with verified_documents as default
|
||||
reranked_documents = verified_documents
|
||||
else:
|
||||
logger.warning("No reranking settings found, using unranked documents")
|
||||
# No reranking, stay with verified_documents as default
|
||||
reranked_documents = verified_documents
|
||||
|
||||
if AGENT_RERANKING_STATS:
|
||||
fit_scores = get_fit_scores(verified_documents, reranked_documents)
|
||||
else:
|
||||
|
||||
@@ -25,15 +25,13 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
DOCUMENT_VERIFICATION_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -88,11 +86,8 @@ def verify_documents(
|
||||
] # default is to treat document as relevant
|
||||
|
||||
try:
|
||||
response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION,
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
|
||||
response = fast_llm.invoke(
|
||||
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
@@ -101,7 +96,7 @@ def verify_documents(
|
||||
):
|
||||
verified_documents = []
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
# In this case, we decide to continue and don't raise an error, as
|
||||
# little harm in letting some docs through that are less relevant.
|
||||
logger.error("LLM Timeout Error - verify documents")
|
||||
|
||||
@@ -67,7 +67,6 @@ class GraphSearchConfig(BaseModel):
|
||||
# Whether to allow creation of refinement questions (and entity extraction, etc.)
|
||||
allow_refinement: bool = True
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
allow_agent_reranking: bool = False
|
||||
|
||||
|
||||
class GraphConfig(BaseModel):
|
||||
|
||||
@@ -25,7 +25,7 @@ logger = setup_logger()
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
def choose_tool(
|
||||
def llm_tool_choice(
|
||||
state: ToolChoiceState,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
@@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
|
||||
write_custom_event("basic_response", packet, writer)
|
||||
|
||||
|
||||
def call_tool(
|
||||
def tool_call(
|
||||
state: ToolChoiceUpdate,
|
||||
config: RunnableConfig,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
@@ -43,9 +43,8 @@ from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
@@ -81,7 +80,6 @@ from onyx.tools.tool_implementations.search.search_tool import SearchResponseSum
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -397,13 +395,11 @@ def summarize_history(
|
||||
)
|
||||
|
||||
try:
|
||||
history_response = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
llm.invoke,
|
||||
history_response = llm.invoke(
|
||||
history_context_prompt,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
logger.error("LLM Timeout Error - summarize history")
|
||||
return (
|
||||
history # this is what is done at this point anyway, so we default to this
|
||||
|
||||
@@ -10,7 +10,6 @@ from onyx.configs.app_configs import SMTP_PORT
|
||||
from onyx.configs.app_configs import SMTP_SERVER
|
||||
from onyx.configs.app_configs import SMTP_USER
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
|
||||
@@ -188,51 +187,23 @@ def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
|
||||
|
||||
def send_user_email_invite(
|
||||
user_email: str, current_user: User, auth_type: AuthType
|
||||
) -> None:
|
||||
def send_user_email_invite(user_email: str, current_user: User) -> None:
|
||||
subject = "Invitation to Join Onyx Organization"
|
||||
heading = "You've Been Invited!"
|
||||
|
||||
# the exact action taken by the user, and thus the message, depends on the auth type
|
||||
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
|
||||
if auth_type == AuthType.CLOUD:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
"or login with Google and complete your registration.</p>"
|
||||
)
|
||||
elif auth_type == AuthType.BASIC:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
"and complete your registration.</p>"
|
||||
)
|
||||
elif auth_type == AuthType.GOOGLE_OAUTH:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to login with Google "
|
||||
"and complete your registration.</p>"
|
||||
)
|
||||
elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML:
|
||||
message += (
|
||||
"<p>To join the organization, please click the button below to"
|
||||
" complete your registration.</p>"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid auth type: {auth_type}")
|
||||
|
||||
message = (
|
||||
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
|
||||
"<p>To join the organization, please click the button below to set a password "
|
||||
"or login with Google and complete your registration.</p>"
|
||||
)
|
||||
cta_text = "Join Organization"
|
||||
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
|
||||
html_content = build_html_email(heading, message, cta_text, cta_link)
|
||||
|
||||
# text content is the fallback for clients that don't support HTML
|
||||
# not as critical, so not having special cases for each auth type
|
||||
text_content = (
|
||||
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
|
||||
"To join the organization, please visit the following link:\n"
|
||||
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
|
||||
"You'll be asked to set a password or login with Google to complete your registration."
|
||||
)
|
||||
if auth_type == AuthType.CLOUD:
|
||||
text_content += "You'll be asked to set a password or login with Google to complete your registration."
|
||||
|
||||
send_email(user_email, subject, html_content, text_content)
|
||||
|
||||
|
||||
|
||||
@@ -42,5 +42,4 @@ def fetch_no_auth_user(
|
||||
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
|
||||
preferences=load_no_auth_user_preferences(store),
|
||||
is_anonymous_user=anonymous_user_enabled,
|
||||
password_configured=False,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
import random
|
||||
import secrets
|
||||
import string
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
@@ -88,6 +86,7 @@ from onyx.db.auth import get_user_db
|
||||
from onyx.db.auth import SQLAlchemyUserAdminDB
|
||||
from onyx.db.engine import get_async_session
|
||||
from onyx.db.engine import get_async_session_with_tenant
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
@@ -95,7 +94,6 @@ from onyx.db.models import User
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_pool import get_async_redis_connection
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.utils import BasicAuthenticationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
@@ -105,11 +103,15 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
@@ -141,30 +143,6 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
|
||||
return email or ""
|
||||
|
||||
|
||||
def generate_password() -> str:
|
||||
lowercase_letters = string.ascii_lowercase
|
||||
uppercase_letters = string.ascii_uppercase
|
||||
digits = string.digits
|
||||
special_characters = string.punctuation
|
||||
|
||||
# Ensure at least one of each required character type
|
||||
password = [
|
||||
secrets.choice(uppercase_letters),
|
||||
secrets.choice(digits),
|
||||
secrets.choice(special_characters),
|
||||
]
|
||||
|
||||
# Fill the rest with a mix of characters
|
||||
remaining_length = 12 - len(password)
|
||||
all_characters = lowercase_letters + uppercase_letters + digits + special_characters
|
||||
password.extend(secrets.choice(all_characters) for _ in range(remaining_length))
|
||||
|
||||
# Shuffle the password to randomize the position of the required characters
|
||||
random.shuffle(password)
|
||||
|
||||
return "".join(password)
|
||||
|
||||
|
||||
def user_needs_to_be_verified() -> bool:
|
||||
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
|
||||
return REQUIRE_EMAIL_VERIFICATION
|
||||
@@ -215,7 +193,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
|
||||
@@ -617,39 +595,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
return user
|
||||
|
||||
async def reset_password_as_admin(self, user_id: uuid.UUID) -> str:
|
||||
"""Admin-only. Generate a random password for a user and return it."""
|
||||
user = await self.get(user_id)
|
||||
new_password = generate_password()
|
||||
await self._update(user, {"password": new_password})
|
||||
return new_password
|
||||
|
||||
async def change_password_if_old_matches(
|
||||
self, user: User, old_password: str, new_password: str
|
||||
) -> None:
|
||||
"""
|
||||
For normal users to change password if they know the old one.
|
||||
Raises 400 if old password doesn't match.
|
||||
"""
|
||||
verified, updated_password_hash = self.password_helper.verify_and_update(
|
||||
old_password, user.hashed_password
|
||||
)
|
||||
if not verified:
|
||||
# Raise some HTTPException (or your custom exception) if old password is invalid:
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid current password",
|
||||
)
|
||||
|
||||
# If the hash was upgraded behind the scenes, we can keep it before setting the new password:
|
||||
if updated_password_hash:
|
||||
user.hashed_password = updated_password_hash
|
||||
|
||||
# Now apply and validate the new password
|
||||
await self._update(user, {"password": new_password})
|
||||
|
||||
|
||||
async def get_user_manager(
|
||||
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
|
||||
@@ -874,9 +819,8 @@ async def current_limited_user(
|
||||
|
||||
async def current_chat_accesssible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> User | None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
return await double_check_user(
|
||||
user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
|
||||
)
|
||||
|
||||
@@ -33,7 +33,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import ColoredFormatter
|
||||
from onyx.utils.logger import PlainFormatter
|
||||
@@ -59,35 +58,13 @@ else:
|
||||
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
|
||||
|
||||
|
||||
class TenantAwareTask(Task):
|
||||
"""A custom base Task that sets tenant_id in a contextvar before running."""
|
||||
|
||||
abstract = True # So Celery knows not to register this as a real task.
|
||||
|
||||
def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
# Grab tenant_id from the kwargs, or fallback to default if missing.
|
||||
tenant_id = kwargs.get("tenant_id", None) or POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Set the context var
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Actually run the task now
|
||||
try:
|
||||
return super().__call__(*args, **kwargs)
|
||||
finally:
|
||||
# Clear or reset after the task runs
|
||||
# so it does not leak into any subsequent tasks on the same worker process
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def on_task_prerun(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -224,7 +201,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
|
||||
Will raise WorkerShutdown to kill the celery worker if the timeout
|
||||
is reached."""
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
@@ -310,7 +287,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# Set up variables for waiting on primary worker
|
||||
WAIT_INTERVAL = 5
|
||||
WAIT_LIMIT = 60
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=None)
|
||||
time_start = time.monotonic()
|
||||
|
||||
logger.info("Waiting for primary worker to be ready...")
|
||||
@@ -462,6 +439,24 @@ class TenantContextFilter(logging.Filter):
|
||||
return True
|
||||
|
||||
|
||||
@task_prerun.connect
|
||||
def set_tenant_id(
|
||||
sender: Any | None = None,
|
||||
task_id: str | None = None,
|
||||
task: Task | None = None,
|
||||
args: tuple[Any, ...] | None = None,
|
||||
kwargs: dict[str, Any] | None = None,
|
||||
**other_kwargs: Any,
|
||||
) -> None:
|
||||
"""Signal handler to set tenant ID in context var before task starts."""
|
||||
tenant_id = (
|
||||
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
|
||||
if kwargs
|
||||
else POSTGRES_DEFAULT_SCHEMA
|
||||
)
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
|
||||
@task_postrun.connect
|
||||
def reset_tenant_id(
|
||||
sender: Any | None = None,
|
||||
|
||||
@@ -132,7 +132,6 @@ class DynamicTenantScheduler(PersistentScheduler):
|
||||
f"Adding options to task {tenant_task_name}: {options}"
|
||||
)
|
||||
tenant_task["options"] = options
|
||||
|
||||
new_schedule[tenant_task_name] = tenant_task
|
||||
|
||||
return new_schedule
|
||||
@@ -257,4 +256,3 @@ def on_setup_logging(
|
||||
|
||||
|
||||
celery_app.conf.beat_scheduler = DynamicTenantScheduler
|
||||
celery_app.conf.task_default_base = app_base.TenantAwareTask
|
||||
|
||||
@@ -20,7 +20,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.heavy")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
|
||||
@@ -21,7 +21,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.indexing")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
|
||||
@@ -23,7 +23,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.light")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
|
||||
@@ -20,7 +20,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_default_tenant
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
from onyx.db.index_attempt import mark_attempt_canceled
|
||||
@@ -38,7 +38,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_stop import RedisConnectorStop
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -47,7 +47,6 @@ logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
celery_app.config_from_object("onyx.background.celery.configs.primary")
|
||||
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
|
||||
|
||||
|
||||
@signals.task_prerun.connect
|
||||
@@ -102,7 +101,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
# This is singleton work that should be done on startup exactly once
|
||||
# by the primary worker. This is unnecessary in the multi tenant scenario
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
|
||||
info: dict[str, Any] = cast(dict, r.info("replication"))
|
||||
@@ -159,7 +158,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
RedisConnectorExternalGroupSync.reset_all(r)
|
||||
|
||||
# mark orphaned index attempts as failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_default_tenant() as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
|
||||
for attempt_id in unfenced_attempt_ids:
|
||||
attempt = get_index_attempt(db_session, attempt_id)
|
||||
@@ -235,7 +234,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
|
||||
r = get_shared_redis_client()
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
if lock.owned():
|
||||
task_logger.debug("Reacquiring primary worker lock.")
|
||||
|
||||
@@ -92,8 +92,7 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
|
||||
|
||||
|
||||
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
|
||||
"""This is a redis specific way to build a list of tasks in a queue and return them
|
||||
as a set.
|
||||
"""This is a redis specific way to build a list of tasks in a queue.
|
||||
|
||||
This helps us read the queue once and then efficiently look for missing tasks
|
||||
in the queue.
|
||||
|
||||
@@ -8,21 +8,16 @@ from celery import Celery
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from pydantic import ValidationError
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.connector_credential_pair import add_deletion_failure_message
|
||||
from onyx.db.connector_credential_pair import (
|
||||
@@ -32,7 +27,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_document_ids_for_connector_credential_pair
|
||||
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
@@ -57,51 +52,6 @@ class TaskDependencyError(RuntimeError):
|
||||
with connector deletion."""
|
||||
|
||||
|
||||
def revoke_tasks_blocking_deletion(
|
||||
redis_connector: RedisConnector, db_session: Session, app: Celery
|
||||
) -> None:
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
try:
|
||||
index_payload = redis_connector_index.payload
|
||||
if index_payload and index_payload.celery_task_id:
|
||||
app.control.revoke(index_payload.celery_task_id)
|
||||
task_logger.info(
|
||||
f"Revoked indexing task {index_payload.celery_task_id}."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking indexing task")
|
||||
|
||||
try:
|
||||
permissions_sync_payload = redis_connector.permissions.payload
|
||||
if permissions_sync_payload and permissions_sync_payload.celery_task_id:
|
||||
app.control.revoke(permissions_sync_payload.celery_task_id)
|
||||
task_logger.info(
|
||||
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking pruning task")
|
||||
|
||||
try:
|
||||
prune_payload = redis_connector.prune.payload
|
||||
if prune_payload and prune_payload.celery_task_id:
|
||||
app.control.revoke(prune_payload.celery_task_id)
|
||||
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking permissions sync task")
|
||||
|
||||
try:
|
||||
external_group_sync_payload = redis_connector.external_group_sync.payload
|
||||
if external_group_sync_payload and external_group_sync_payload.celery_task_id:
|
||||
app.control.revoke(external_group_sync_payload.celery_task_id)
|
||||
task_logger.info(
|
||||
f"Revoked external group sync task {external_group_sync_payload.celery_task_id}."
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception("Exception while revoking external group sync task")
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
ignore_result=True,
|
||||
@@ -112,45 +62,29 @@ def revoke_tasks_blocking_deletion(
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
# Prevent this task from overlapping with itself
|
||||
# these tasks should never overlap
|
||||
if not lock_beat.acquire(blocking=False):
|
||||
return None
|
||||
|
||||
try:
|
||||
# we want to run this less frequently than the overall task
|
||||
lock_beat.reacquire()
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
|
||||
# clear fences that don't have associated celery tasks in progress
|
||||
try:
|
||||
validate_connector_deletion_fences(
|
||||
tenant_id, r, r_replica, r_celery, lock_beat
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
"Exception while validating connector deletion fences"
|
||||
)
|
||||
|
||||
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
|
||||
|
||||
# collect cc_pair_ids
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair.id)
|
||||
|
||||
# try running cleanup on the cc_pair_ids
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
try:
|
||||
try_generate_document_cc_pair_cleanup_tasks(
|
||||
@@ -158,38 +92,9 @@ def check_for_connector_deletion_task(
|
||||
)
|
||||
except TaskDependencyError as e:
|
||||
# this means we wanted to start deleting but dependent tasks were running
|
||||
# on the first error, we set a stop signal and revoke the dependent tasks
|
||||
# on subsequent errors, we hard reset blocking fences after our specified timeout
|
||||
# is exceeded
|
||||
# Leave a stop signal to clear indexing and pruning tasks more quickly
|
||||
task_logger.info(str(e))
|
||||
|
||||
if not redis_connector.stop.fenced:
|
||||
# one time revoke of celery tasks
|
||||
task_logger.info("Revoking any tasks blocking deletion.")
|
||||
revoke_tasks_blocking_deletion(
|
||||
redis_connector, db_session, self.app
|
||||
)
|
||||
redis_connector.stop.set_fence(True)
|
||||
redis_connector.stop.set_timeout()
|
||||
else:
|
||||
# stop signal already set
|
||||
if redis_connector.stop.timed_out:
|
||||
# waiting too long, just reset blocking fences
|
||||
task_logger.info(
|
||||
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
|
||||
)
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings.id
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
redis_connector.prune.reset()
|
||||
redis_connector.permissions.reset()
|
||||
redis_connector.external_group_sync.reset()
|
||||
else:
|
||||
# just wait
|
||||
pass
|
||||
redis_connector.stop.set_fence(True)
|
||||
else:
|
||||
# clear the stop signal if it exists ... no longer needed
|
||||
redis_connector.stop.set_fence(False)
|
||||
@@ -264,7 +169,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
return None
|
||||
|
||||
# set a basic fence to start
|
||||
redis_connector.delete.set_active()
|
||||
fence_payload = RedisConnectorDeletePayload(
|
||||
num_tasks=None,
|
||||
submitted=datetime.now(timezone.utc),
|
||||
@@ -373,7 +277,7 @@ def monitor_connector_deletion_taskset(
|
||||
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
|
||||
)
|
||||
if remaining > 0:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
@@ -383,7 +287,7 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
return
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -497,171 +401,3 @@ def monitor_connector_deletion_taskset(
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
|
||||
|
||||
def validate_connector_deletion_fences(
|
||||
tenant_id: str | None,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
) -> None:
|
||||
# building lookup table can be expensive, so we won't bother
|
||||
# validating until the queue is small
|
||||
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
|
||||
|
||||
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
|
||||
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
|
||||
return
|
||||
|
||||
queued_upsert_tasks = celery_get_queued_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
|
||||
)
|
||||
|
||||
# validate all existing connector deletion jobs
|
||||
lock_beat.reacquire()
|
||||
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
|
||||
for key in keys:
|
||||
key_bytes = cast(bytes, key)
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
|
||||
continue
|
||||
|
||||
validate_connector_deletion_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
queued_upsert_tasks,
|
||||
r,
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
return
|
||||
|
||||
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str | None,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
r: Redis,
|
||||
) -> None:
|
||||
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
|
||||
This can happen if the indexing worker hard crashes or is terminated.
|
||||
Being in this bad state means the fence will never clear without help, so this function
|
||||
gives the help.
|
||||
|
||||
How this works:
|
||||
1. This function renews the active signal with a 5 minute TTL under the following conditions
|
||||
1.2. When the task is seen in the redis queue
|
||||
1.3. When the task is seen in the reserved / prefetched list
|
||||
|
||||
2. Externally, the active signal is renewed when:
|
||||
2.1. The fence is created
|
||||
2.2. The indexing watchdog checks the spawned task.
|
||||
|
||||
3. The TTL allows us to get through the transitions on fence startup
|
||||
and when the task starts executing.
|
||||
|
||||
More TTL clarification: it is seemingly impossible to exactly query Celery for
|
||||
whether a task is in the queue or currently executing.
|
||||
1. An unknown task id is always returned as state PENDING.
|
||||
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
|
||||
and the time it actually starts on the worker.
|
||||
|
||||
queued_tasks: the celery queue of lightweight permission sync tasks
|
||||
reserved_tasks: prefetched tasks for sync task generator
|
||||
"""
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
if cc_pair_id_str is None:
|
||||
task_logger.warning(
|
||||
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
|
||||
)
|
||||
return
|
||||
|
||||
cc_pair_id = int(cc_pair_id_str)
|
||||
# parse out metadata and initialize the helper class with it
|
||||
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
|
||||
|
||||
# check to see if the fence/payload exists
|
||||
if not redis_connector.delete.fenced:
|
||||
return
|
||||
|
||||
# in the cloud, the payload format may have changed ...
|
||||
# it's a little sloppy, but just reset the fence for now if that happens
|
||||
# TODO: add intentional cleanup/abort logic
|
||||
try:
|
||||
payload = redis_connector.delete.payload
|
||||
except ValidationError:
|
||||
task_logger.exception(
|
||||
"validate_connector_deletion_fence - "
|
||||
"Resetting fence because fence schema is out of date: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
if not payload:
|
||||
return
|
||||
|
||||
# OK, there's actually something for us to validate
|
||||
|
||||
# look up every task in the current taskset in the celery queue
|
||||
# every entry in the taskset should have an associated entry in the celery task queue
|
||||
# because we get the celery tasks first, the entries in our own permissions taskset
|
||||
# should be roughly a subset of the tasks in celery
|
||||
|
||||
# this check isn't very exact, but should be sufficient over a period of time
|
||||
# A single successful check over some number of attempts is sufficient.
|
||||
|
||||
# TODO: if the number of tasks in celery is much lower than than the taskset length
|
||||
# we might be able to shortcut the lookup since by definition some of the tasks
|
||||
# must not exist in celery.
|
||||
|
||||
tasks_scanned = 0
|
||||
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
|
||||
|
||||
for member in r.sscan_iter(redis_connector.delete.taskset_key):
|
||||
tasks_scanned += 1
|
||||
|
||||
member_bytes = cast(bytes, member)
|
||||
member_str = member_bytes.decode("utf-8")
|
||||
if member_str in queued_tasks:
|
||||
continue
|
||||
|
||||
tasks_not_in_celery += 1
|
||||
|
||||
task_logger.info(
|
||||
"validate_connector_deletion_fence task check: "
|
||||
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
|
||||
)
|
||||
|
||||
# we're active if there are still tasks to run and those tasks all exist in celery
|
||||
if tasks_scanned > 0 and tasks_not_in_celery == 0:
|
||||
redis_connector.delete.set_active()
|
||||
return
|
||||
|
||||
# we may want to enable this check if using the active task list somehow isn't good enough
|
||||
# if redis_connector_index.generator_locked():
|
||||
# logger.info(f"{payload.celery_task_id} is currently executing.")
|
||||
|
||||
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
|
||||
# but they still might be there due to gaps in our ability to check states during transitions
|
||||
# Checking the active signal safeguards us against these transition periods
|
||||
# (which has a duration that allows us to bridge those gaps)
|
||||
if redis_connector.delete.active():
|
||||
return
|
||||
|
||||
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
|
||||
task_logger.warning(
|
||||
"validate_connector_deletion_fence - "
|
||||
"Resetting fence because no associated celery tasks were found: "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"fence={fence_key}"
|
||||
)
|
||||
|
||||
redis_connector.delete.reset()
|
||||
return
|
||||
|
||||
@@ -30,7 +30,6 @@ from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
|
||||
@@ -43,12 +42,10 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_permissions_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
@@ -66,7 +63,6 @@ from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.redis.redis_pool import redis_lock_dump
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import doc_permission_sync_ctx
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -123,13 +119,13 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
# TODO(rkuo): merge into check function after lookup table for fences is added
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -144,7 +140,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
try:
|
||||
# get all cc pairs that need to be synced
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
@@ -193,23 +189,16 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_permissions_taskset(
|
||||
tenant_id, key_bytes, r, db_session
|
||||
)
|
||||
task_logger.info(f"check_for_doc_permissions_sync finished: tenant={tenant_id}")
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
task_logger.warning(
|
||||
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id} {error_msg}"
|
||||
)
|
||||
task_logger.exception(
|
||||
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
@@ -258,7 +247,7 @@ def try_creating_permissions_sync_task(
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
@@ -293,19 +282,13 @@ def try_creating_permissions_sync_task(
|
||||
redis_connector.permissions.set_fence(payload)
|
||||
|
||||
payload_id = payload.id
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
task_logger.warning(
|
||||
f"Unexpected try_creating_permissions_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
|
||||
task_logger.info(
|
||||
f"try_creating_permissions_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
|
||||
)
|
||||
return payload_id
|
||||
|
||||
|
||||
@@ -338,7 +321,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
@@ -395,7 +378,7 @@ def connector_permission_sync_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -405,30 +388,6 @@ def connector_permission_sync_generator_task(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
created = validate_ccpair_for_user(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
task_logger.warning(
|
||||
f"Unable to create connector credential pair for id: {cc_pair_id}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
@@ -480,10 +439,6 @@ def connector_permission_sync_generator_task(
|
||||
redis_connector.permissions.generator_complete = tasks_generated
|
||||
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
task_logger.warning(
|
||||
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
|
||||
)
|
||||
task_logger.exception(
|
||||
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
|
||||
)
|
||||
@@ -518,8 +473,6 @@ def update_external_document_permissions_task(
|
||||
) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
)
|
||||
@@ -527,8 +480,7 @@ def update_external_document_permissions_task(
|
||||
external_access = document_external_access.external_access
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Add the users to the DB if they don't exist
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
@@ -559,28 +511,13 @@ def update_external_document_permissions_task(
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
task_logger.warning(
|
||||
f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"update_external_document_permissions_task exceptioned: "
|
||||
f"Exception in update_external_document_permissions_task: "
|
||||
f"connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
finally:
|
||||
task_logger.info(
|
||||
f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
|
||||
)
|
||||
|
||||
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
|
||||
return False
|
||||
|
||||
task_logger.info(
|
||||
f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -37,12 +37,9 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.factory import validate_ccpair_for_user
|
||||
from onyx.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
@@ -58,7 +55,6 @@ from onyx.redis.redis_connector_ext_group_sync import (
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -126,8 +122,8 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -144,7 +140,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
|
||||
try:
|
||||
cc_pair_ids_to_sync: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# We only want to sync one cc_pair per source type in
|
||||
@@ -152,10 +148,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
|
||||
# These are ordered by cc_pair id so the first one is the one we want
|
||||
cc_pairs_to_dedupe = get_cc_pairs_by_source(
|
||||
db_session,
|
||||
source,
|
||||
access_type=AccessType.SYNC,
|
||||
status=ConnectorCredentialPairStatus.ACTIVE,
|
||||
db_session, source, only_sync=True
|
||||
)
|
||||
# We only want to sync one cc_pair per source type
|
||||
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
|
||||
@@ -202,17 +195,12 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
task_logger.warning(
|
||||
f"Unexpected check_for_external_group_sync exception: tenant={tenant_id} {error_msg}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
|
||||
task_logger.info(f"check_for_external_group_sync finished: tenant={tenant_id}")
|
||||
return True
|
||||
|
||||
|
||||
@@ -242,7 +230,7 @@ def try_creating_external_group_sync_task(
|
||||
# create before setting fence to avoid race condition where the monitoring
|
||||
# task updates the sync record before it is created
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
insert_sync_record(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
@@ -279,19 +267,12 @@ def try_creating_external_group_sync_task(
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
payload_id = payload.id
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
task_logger.warning(
|
||||
f"Unexpected try_creating_external_group_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
|
||||
)
|
||||
return None
|
||||
|
||||
task_logger.info(
|
||||
f"try_creating_external_group_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
|
||||
)
|
||||
return payload_id
|
||||
|
||||
|
||||
@@ -315,7 +296,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
@@ -376,41 +357,16 @@ def connector_external_group_sync_generator_task(
|
||||
payload.started = datetime.now(timezone.utc)
|
||||
redis_connector.external_group_sync.set_fence(payload)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
eager_load_credential=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(
|
||||
f"No connector credential pair found for id: {cc_pair_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
created = validate_ccpair_for_user(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
task_logger.warning(
|
||||
f"Unable to create connector credential pair for id: {cc_pair_id}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
|
||||
)
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
raise
|
||||
|
||||
source_type = cc_pair.connector.source
|
||||
|
||||
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
|
||||
@@ -422,23 +378,12 @@ def connector_external_group_sync_generator_task(
|
||||
logger.info(
|
||||
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
|
||||
)
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
try:
|
||||
external_user_groups = ext_group_sync_func(cc_pair)
|
||||
except ConnectorValidationError as e:
|
||||
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
raise e
|
||||
|
||||
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
|
||||
|
||||
logger.info(
|
||||
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
|
||||
)
|
||||
logger.debug(f"New external user groups: {external_user_groups}")
|
||||
|
||||
replace_user__ext_group_for_cc_pair(
|
||||
db_session=db_session,
|
||||
@@ -459,19 +404,11 @@ def connector_external_group_sync_generator_task(
|
||||
sync_status=SyncStatus.SUCCESS,
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
task_logger.warning(
|
||||
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id} {error_msg}"
|
||||
)
|
||||
task_logger.exception(
|
||||
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
|
||||
)
|
||||
|
||||
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
|
||||
task_logger.exception(msg)
|
||||
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
update_sync_record_status(
|
||||
db_session=db_session,
|
||||
entity_id=cc_pair_id,
|
||||
@@ -522,6 +459,7 @@ def validate_external_group_sync_fences(
|
||||
)
|
||||
|
||||
lock_beat.reacquire()
|
||||
|
||||
return
|
||||
|
||||
|
||||
|
||||
@@ -41,18 +41,16 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
|
||||
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.configs.constants import OnyxRedisSignals
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.db.connector import mark_ccpair_with_indexing_trigger
|
||||
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
@@ -92,9 +90,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
SUCCEEDED = "succeeded"
|
||||
|
||||
SPAWN_FAILED = "spawn_failed" # connector spawn failed
|
||||
SPAWN_NOT_ALIVE = (
|
||||
"spawn_not_alive" # spawn succeeded but process did not come alive
|
||||
)
|
||||
|
||||
BLOCKED_BY_DELETION = "blocked_by_deletion"
|
||||
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
|
||||
@@ -108,9 +103,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
"index_attempt_mismatch" # expected index attempt metadata not found in db
|
||||
)
|
||||
|
||||
CONNECTOR_VALIDATION_ERROR = (
|
||||
"connector_validation_error" # the connector validation failed
|
||||
)
|
||||
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
|
||||
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
|
||||
|
||||
@@ -120,8 +112,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
# the watchdog terminated the task due to no activity
|
||||
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
|
||||
|
||||
# NOTE: this may actually be the same as SIGKILL, but parsed differently by python
|
||||
# consolidate once we know more
|
||||
OUT_OF_MEMORY = "out_of_memory"
|
||||
|
||||
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
|
||||
@@ -131,7 +121,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
|
||||
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
|
||||
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
|
||||
IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
|
||||
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
|
||||
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
|
||||
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
|
||||
@@ -148,8 +137,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
|
||||
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
|
||||
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
|
||||
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
|
||||
137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
|
||||
247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
|
||||
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
|
||||
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
|
||||
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
|
||||
@@ -361,13 +348,12 @@ def monitor_ccpair_indexing_taskset(
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
tasks_created = 0
|
||||
locked = False
|
||||
redis_client = get_redis_client()
|
||||
redis_client_replica = get_redis_replica_client()
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
@@ -405,7 +391,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# 1/3: KICKOFF
|
||||
|
||||
# check for search settings swap
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_search_settings = check_index_swap(db_session=db_session)
|
||||
current_search_settings = get_current_search_settings(db_session)
|
||||
# So that the first time users aren't surprised by really slow speed of first
|
||||
@@ -426,7 +412,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
# gather cc_pair_ids
|
||||
lock_beat.reacquire()
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = fetch_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
@@ -436,7 +422,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
search_settings_list = get_active_search_settings_list(db_session)
|
||||
for search_settings_instance in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
@@ -514,7 +500,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
# Fail any index attempts in the DB that don't have fences
|
||||
# This shouldn't ever happen!
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
|
||||
db_session, redis_client
|
||||
)
|
||||
@@ -566,7 +552,7 @@ def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_indexing_taskset(
|
||||
tenant_id, key_bytes, redis_client_replica, db_session
|
||||
)
|
||||
@@ -597,8 +583,8 @@ def connector_indexing_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str | None,
|
||||
is_ee: bool,
|
||||
) -> int | None:
|
||||
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -649,7 +635,7 @@ def connector_indexing_task(
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
if redis_connector.delete.fenced:
|
||||
raise SimpleJobException(
|
||||
@@ -743,7 +729,7 @@ def connector_indexing_task(
|
||||
redis_connector_index.set_fence(payload)
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if not attempt:
|
||||
raise SimpleJobException(
|
||||
@@ -778,9 +764,9 @@ def connector_indexing_task(
|
||||
callback = IndexingCallback(
|
||||
os.getppid(),
|
||||
redis_connector,
|
||||
redis_connector_index,
|
||||
lock,
|
||||
r,
|
||||
redis_connector_index,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
@@ -802,15 +788,6 @@ def connector_indexing_task(
|
||||
# get back the total number of indexed docs and return it
|
||||
n_final_progress = redis_connector_index.get_progress()
|
||||
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
|
||||
except ConnectorValidationError:
|
||||
raise SimpleJobException(
|
||||
f"Indexing task failed: attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}",
|
||||
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Indexing spawned task failed: attempt={index_attempt_id} "
|
||||
@@ -818,8 +795,8 @@ def connector_indexing_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
raise e
|
||||
|
||||
raise e
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
@@ -899,9 +876,6 @@ def connector_indexing_proxy_task(
|
||||
|
||||
TODO(rkuo): refactor this so that there is a single return path where we canonically
|
||||
log the result of running this function.
|
||||
|
||||
NOTE: we try/except all db access in this function because as a watchdog, this function
|
||||
needs to be extremely stable.
|
||||
"""
|
||||
start = time.monotonic()
|
||||
|
||||
@@ -927,18 +901,18 @@ def connector_indexing_proxy_task(
|
||||
task_logger.error("self.request.id is None!")
|
||||
|
||||
client = SimpleJobClient()
|
||||
task_logger.info(f"submitting connector_indexing_task with tenant_id={tenant_id}")
|
||||
|
||||
job = client.submit(
|
||||
connector_indexing_task,
|
||||
index_attempt_id,
|
||||
cc_pair_id,
|
||||
search_settings_id,
|
||||
global_version.is_ee_version(),
|
||||
tenant_id,
|
||||
global_version.is_ee_version(),
|
||||
pure=False,
|
||||
)
|
||||
|
||||
if not job or not job.process:
|
||||
if not job:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
@@ -949,39 +923,13 @@ def connector_indexing_proxy_task(
|
||||
)
|
||||
return
|
||||
|
||||
# Ensure the process has moved out of the starting state
|
||||
num_waits = 0
|
||||
while True:
|
||||
if num_waits > 15:
|
||||
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - finished",
|
||||
status=str(result.status.value),
|
||||
exit_code=str(result.exit_code),
|
||||
)
|
||||
)
|
||||
job.release()
|
||||
return
|
||||
|
||||
if job.process.is_alive() or job.process.exitcode is not None:
|
||||
break
|
||||
|
||||
sleep(1)
|
||||
num_waits += 1
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - spawn succeeded",
|
||||
pid=str(job.process.pid),
|
||||
)
|
||||
)
|
||||
task_logger.info(log_builder.build("Indexing watchdog - spawn succeeded"))
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
@@ -992,9 +940,6 @@ def connector_indexing_proxy_task(
|
||||
index_attempt.connector_credential_pair.connector.source.value
|
||||
)
|
||||
|
||||
redis_connector_index.set_active() # renew active signal
|
||||
redis_connector_index.set_connector_active() # prime the connective active signal
|
||||
|
||||
while True:
|
||||
sleep(5)
|
||||
|
||||
@@ -1020,7 +965,7 @@ def connector_indexing_proxy_task(
|
||||
job.release()
|
||||
break
|
||||
|
||||
# if a termination signal is detected, break (exit point will clean up)
|
||||
# if a termination signal is detected, clean up and break
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
log_builder.build("Indexing watchdog - termination signal detected")
|
||||
@@ -1029,24 +974,10 @@ def connector_indexing_proxy_task(
|
||||
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
|
||||
break
|
||||
|
||||
# if activity timeout is detected, break (exit point will clean up)
|
||||
if not redis_connector_index.connector_active():
|
||||
task_logger.warning(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - activity timeout exceeded",
|
||||
timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
|
||||
)
|
||||
)
|
||||
|
||||
result.status = (
|
||||
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
|
||||
)
|
||||
break
|
||||
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
@@ -1056,29 +987,25 @@ def connector_indexing_proxy_task(
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception looking up index attempt"
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
|
||||
if isinstance(e, ConnectorValidationError):
|
||||
# No need to expose full stack trace for validation errors
|
||||
result.exception_str = str(e)
|
||||
else:
|
||||
result.exception_str = traceback.format_exc()
|
||||
result.exception_str = traceback.format_exc()
|
||||
|
||||
# handle exit and reporting
|
||||
elapsed = time.monotonic() - start
|
||||
if result.exception_str is not None:
|
||||
# print with exception
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
failure_reason = (
|
||||
f"Spawned task exceptioned: exit_code={result.exit_code}"
|
||||
)
|
||||
@@ -1118,13 +1045,15 @@ def connector_indexing_proxy_task(
|
||||
# print without exception
|
||||
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
except Exception:
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled"
|
||||
@@ -1132,25 +1061,6 @@ def connector_indexing_proxy_task(
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Indexing watchdog - activity timeout exceeded: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
log_builder.build(
|
||||
"Indexing watchdog - transient exception marking index attempt as failed"
|
||||
)
|
||||
)
|
||||
job.cancel()
|
||||
else:
|
||||
pass
|
||||
|
||||
task_logger.info(
|
||||
log_builder.build(
|
||||
@@ -1185,7 +1095,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
|
||||
|
||||
try:
|
||||
locked = True
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
old_attempts = get_index_attempts_with_old_checkpoints(db_session)
|
||||
for attempt in old_attempts:
|
||||
task_logger.info(
|
||||
@@ -1221,5 +1131,5 @@ def cleanup_checkpoint_task(
|
||||
self: Task, *, index_attempt_id: int, tenant_id: str | None
|
||||
) -> None:
|
||||
"""Clean up a checkpoint for a given index attempt"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
cleanup_checkpoint(db_session, index_attempt_id)
|
||||
|
||||
@@ -23,7 +23,7 @@ from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisConstants
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import IndexModelStatus
|
||||
@@ -93,25 +93,27 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
return unfenced_attempts
|
||||
|
||||
|
||||
class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
class IndexingCallback(IndexingHeartbeatInterface):
|
||||
PARENT_CHECK_INTERVAL = 60
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
redis_connector: RedisConnector,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
):
|
||||
super().__init__()
|
||||
self.parent_pid = parent_pid
|
||||
self.redis_connector: RedisConnector = redis_connector
|
||||
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
|
||||
self.redis_lock: RedisLock = redis_lock
|
||||
self.redis_client = redis_client
|
||||
self.started: datetime = datetime.now(timezone.utc)
|
||||
self.redis_lock.reacquire()
|
||||
|
||||
self.last_tag: str = f"{self.__class__.__name__}.__init__"
|
||||
self.last_tag: str = "IndexingCallback.__init__"
|
||||
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
|
||||
self.last_lock_monotonic = time.monotonic()
|
||||
|
||||
@@ -125,8 +127,8 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
|
||||
# with daemon=True. It seems likely some indexing tasks will need to spawn other processes
|
||||
# eventually, which daemon=True prevents, so leave this code in until we're ready to test it.
|
||||
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
|
||||
# so leave this code in until we're ready to test it.
|
||||
|
||||
# if self.parent_pid:
|
||||
# # check if the parent pid is alive so we aren't running as a zombie
|
||||
@@ -141,6 +143,8 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
# self.last_parent_check = now
|
||||
|
||||
try:
|
||||
self.redis_connector.prune.set_active()
|
||||
|
||||
current_time = time.monotonic()
|
||||
if current_time - self.last_lock_monotonic >= (
|
||||
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
|
||||
@@ -152,7 +156,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
self.last_tag = tag
|
||||
except LockError:
|
||||
logger.exception(
|
||||
f"{self.__class__.__name__} - lock.reacquire exceptioned: "
|
||||
f"IndexingCallback - lock.reacquire exceptioned: "
|
||||
f"lock_timeout={self.redis_lock.timeout} "
|
||||
f"start={self.started} "
|
||||
f"last_tag={self.last_tag} "
|
||||
@@ -163,24 +167,6 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
|
||||
redis_lock_dump(self.redis_lock, self.redis_client)
|
||||
raise
|
||||
|
||||
|
||||
class IndexingCallback(IndexingCallbackBase):
|
||||
def __init__(
|
||||
self,
|
||||
parent_pid: int,
|
||||
redis_connector: RedisConnector,
|
||||
redis_lock: RedisLock,
|
||||
redis_client: Redis,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
):
|
||||
super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
|
||||
|
||||
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
|
||||
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
self.redis_connector_index.set_active()
|
||||
self.redis_connector_index.set_connector_active()
|
||||
super().progress(tag, amount)
|
||||
self.redis_client.incrby(
|
||||
self.redis_connector_index.generator_progress_key, amount
|
||||
)
|
||||
@@ -332,7 +318,7 @@ def validate_indexing_fences(
|
||||
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
continue
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
validate_indexing_fence(
|
||||
tenant_id,
|
||||
key_bytes,
|
||||
|
||||
@@ -8,7 +8,7 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import LLMProvider
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | N
|
||||
return None
|
||||
|
||||
# Then update the database with the fetched models
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Get the default LLM provider
|
||||
default_provider = (
|
||||
db_session.query(LLMProvider)
|
||||
|
||||
@@ -26,8 +26,7 @@ from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_db_current_time
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
@@ -43,6 +42,7 @@ from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
|
||||
|
||||
@@ -668,7 +668,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
task_logger.info("Starting background monitoring")
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
lock_monitoring: RedisLock = r.lock(
|
||||
OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
|
||||
@@ -683,7 +683,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
try:
|
||||
# Get Redis client for Celery broker
|
||||
redis_celery = self.app.broker_connection().channel().client # type: ignore
|
||||
redis_std = get_redis_client()
|
||||
redis_std = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Define metric collection functions and their dependencies
|
||||
metric_functions: list[Callable[[], list[Metric]]] = [
|
||||
@@ -693,7 +693,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
]
|
||||
|
||||
# Collect and log each metric
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
for metric_fn in metric_functions:
|
||||
metrics = metric_fn()
|
||||
for metric in metrics:
|
||||
@@ -771,11 +771,12 @@ def cloud_check_alembic() -> bool | None:
|
||||
if tenant_id is None:
|
||||
continue
|
||||
|
||||
with get_session_with_shared_schema() as session:
|
||||
with get_session_with_tenant(tenant_id=None) as session:
|
||||
try:
|
||||
result = session.execute(
|
||||
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
|
||||
)
|
||||
|
||||
result_scalar: str | None = result.scalar_one_or_none()
|
||||
if result_scalar is None:
|
||||
raise ValueError("Alembic version should not be None.")
|
||||
|
||||
@@ -15,7 +15,7 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import PostgresAdvisoryLocks
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -36,7 +36,7 @@ def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
|
||||
@@ -21,7 +21,7 @@ from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
|
||||
@@ -41,7 +41,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.document import get_documents_for_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
@@ -55,7 +55,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.redis.redis_pool import get_redis_replica_client
|
||||
from onyx.server.utils import make_short_id
|
||||
from onyx.utils.logger import format_error_for_logging
|
||||
from onyx.utils.logger import LoggerContextVars
|
||||
from onyx.utils.logger import pruning_ctx
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -63,12 +62,6 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class PruneCallback(IndexingCallbackBase):
|
||||
def progress(self, tag: str, amount: int) -> None:
|
||||
self.redis_connector.prune.set_active()
|
||||
super().progress(tag, amount)
|
||||
|
||||
|
||||
"""Jobs / utils for kicking off pruning tasks."""
|
||||
|
||||
|
||||
@@ -115,8 +108,8 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
@@ -134,14 +127,14 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
# but pruning only kicks off once per hour
|
||||
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
|
||||
cc_pair_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
for cc_pair_entry in cc_pairs:
|
||||
cc_pair_ids.append(cc_pair_entry.id)
|
||||
|
||||
for cc_pair_id in cc_pair_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
@@ -189,20 +182,18 @@ def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
|
||||
key_str = key_bytes.decode("utf-8")
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
task_logger.warning(f"Unexpected pruning check exception: {error_msg}")
|
||||
except Exception:
|
||||
task_logger.exception("Unexpected exception during pruning check")
|
||||
finally:
|
||||
if lock_beat.owned():
|
||||
lock_beat.release()
|
||||
task_logger.info(f"check_for_pruning finished: tenant={tenant_id}")
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -304,19 +295,13 @@ def try_creating_prune_generator_task(
|
||||
redis_connector.prune.set_fence(payload)
|
||||
|
||||
payload_id = payload.id
|
||||
except Exception as e:
|
||||
error_msg = format_error_for_logging(e)
|
||||
task_logger.warning(
|
||||
f"Unexpected try_creating_prune_generator_task exception: cc_pair={cc_pair.id} {error_msg}"
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
|
||||
return None
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
task_logger.info(
|
||||
f"try_creating_prune_generator_task finished: cc_pair={cc_pair.id} payload_id={payload_id}"
|
||||
)
|
||||
|
||||
return payload_id
|
||||
|
||||
|
||||
@@ -352,7 +337,7 @@ def connector_pruning_generator_task(
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
r = get_redis_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# this wait is needed to avoid a race condition where
|
||||
# the primary worker sends the task and it is immediately executed
|
||||
@@ -410,7 +395,7 @@ def connector_pruning_generator_task(
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pair = get_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
@@ -440,7 +425,6 @@ def connector_pruning_generator_task(
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"connector_source={cc_pair.connector.source}"
|
||||
)
|
||||
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session,
|
||||
cc_pair.connector.source,
|
||||
@@ -450,11 +434,12 @@ def connector_pruning_generator_task(
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
redis_connector.new_index(search_settings.id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings.id)
|
||||
|
||||
callback = PruneCallback(
|
||||
callback = IndexingCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
redis_connector_index,
|
||||
lock,
|
||||
r,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import time
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
|
||||
import httpx
|
||||
@@ -28,7 +27,7 @@ from onyx.db.document import mark_document_as_modified
|
||||
from onyx.db.document import mark_document_as_synced
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.document_index.interfaces import VespaDocumentFields
|
||||
@@ -46,24 +45,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
|
||||
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
class OnyxCeleryTaskCompletionStatus(str, Enum):
|
||||
"""The different statuses the watchdog can finish with.
|
||||
|
||||
TODO: create broader success/failure/abort categories
|
||||
"""
|
||||
|
||||
UNDEFINED = "undefined"
|
||||
|
||||
SUCCEEDED = "succeeded"
|
||||
|
||||
SKIPPED = "skipped"
|
||||
|
||||
SOFT_TIME_LIMIT = "soft_time_limit"
|
||||
|
||||
NON_RETRYABLE_EXCEPTION = "non_retryable_exception"
|
||||
RETRYABLE_EXCEPTION = "retryable_exception"
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
@@ -97,10 +78,8 @@ def document_by_cc_pair_cleanup_task(
|
||||
|
||||
start = time.monotonic()
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
action = "skip"
|
||||
chunks_affected = 0
|
||||
|
||||
@@ -126,14 +105,10 @@ def document_by_cc_pair_cleanup_task(
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
delete_documents_complete__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=[document_id],
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
elif count > 1:
|
||||
action = "update"
|
||||
|
||||
@@ -177,11 +152,10 @@ def document_by_cc_pair_cleanup_task(
|
||||
)
|
||||
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
else:
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
|
||||
pass
|
||||
|
||||
db_session.commit()
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
@@ -193,79 +167,57 @@ def document_by_cc_pair_cleanup_task(
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
|
||||
return False
|
||||
except Exception as ex:
|
||||
e: Exception | None = None
|
||||
while True:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
completion_status = (
|
||||
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
)
|
||||
break
|
||||
|
||||
task_logger.exception(
|
||||
f"document_by_cc_pair_cleanup_task exceptioned: doc={document_id}"
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
|
||||
if (
|
||||
self.max_retries is not None
|
||||
and self.request.retries >= self.max_retries
|
||||
):
|
||||
# This is the last attempt! mark the document as dirty in the db so that it
|
||||
# eventually gets fixed out of band via stale document reconciliation
|
||||
task_logger.warning(
|
||||
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
|
||||
f"doc={document_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# delete the cc pair relationship now and let reconciliation clean it up
|
||||
# in vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
mark_document_as_modified(document_id, db_session)
|
||||
completion_status = (
|
||||
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
)
|
||||
break
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.exception(f"Unexpected exception: doc={document_id}")
|
||||
|
||||
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
|
||||
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
|
||||
break # we won't hit this, but it looks weird not to have it
|
||||
finally:
|
||||
task_logger.info(
|
||||
f"document_by_cc_pair_cleanup_task completed: status={completion_status.value} doc={document_id}"
|
||||
)
|
||||
|
||||
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
else:
|
||||
# This is the last attempt! mark the document as dirty in the db so that it
|
||||
# eventually gets fixed out of band via stale document reconciliation
|
||||
task_logger.warning(
|
||||
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
|
||||
f"doc={document_id}"
|
||||
)
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# delete the cc pair relationship now and let reconciliation clean it up
|
||||
# in vespa
|
||||
delete_document_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_id=document_id,
|
||||
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
)
|
||||
mark_document_as_modified(document_id, db_session)
|
||||
return False
|
||||
|
||||
task_logger.info(f"document_by_cc_pair_cleanup_task finished: doc={document_id}")
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
|
||||
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
@@ -35,7 +34,7 @@ from onyx.db.document_set import fetch_document_sets
|
||||
from onyx.db.document_set import fetch_document_sets_for_document
|
||||
from onyx.db.document_set import get_document_set_by_id
|
||||
from onyx.db.document_set import mark_document_set_as_synced
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import DocumentSet
|
||||
@@ -85,8 +84,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
|
||||
time_start = time.monotonic()
|
||||
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
r_replica = get_redis_replica_client(tenant_id=tenant_id)
|
||||
|
||||
lock_beat: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
|
||||
@@ -99,7 +98,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
|
||||
try:
|
||||
# 1/3: KICKOFF
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_stale_document_sync_tasks(
|
||||
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
@@ -107,7 +106,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
# region document set scan
|
||||
lock_beat.reacquire()
|
||||
document_set_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# check if any document sets are not synced
|
||||
document_set_info = fetch_document_sets(
|
||||
user_id=None, db_session=db_session, include_outdated=True
|
||||
@@ -118,7 +117,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
|
||||
for document_set_id in document_set_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_document_set_sync_tasks(
|
||||
self.app, document_set_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
@@ -137,7 +136,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
pass
|
||||
else:
|
||||
usergroup_ids: list[int] = []
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session=db_session, only_up_to_date=False
|
||||
)
|
||||
@@ -147,7 +146,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
|
||||
for usergroup_id in usergroup_ids:
|
||||
lock_beat.reacquire()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
try_generate_user_group_sync_tasks(
|
||||
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
|
||||
)
|
||||
@@ -168,7 +167,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
|
||||
monitor_connector_taskset(r)
|
||||
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
|
||||
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
|
||||
monitor_usergroup_taskset = (
|
||||
@@ -178,7 +177,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | No
|
||||
noop_fallback,
|
||||
)
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
|
||||
|
||||
except SoftTimeLimitExceeded:
|
||||
@@ -524,14 +523,12 @@ def monitor_document_set_taskset(
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, *, tenant_id: str | None
|
||||
self: Task, document_id: str, tenant_id: str | None
|
||||
) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
active_search_settings = get_active_search_settings(db_session)
|
||||
doc_index = get_default_document_index(
|
||||
search_settings=active_search_settings.primary,
|
||||
@@ -543,103 +540,75 @@ def vespa_metadata_sync_task(
|
||||
|
||||
doc = get_document(document_id, db_session)
|
||||
if not doc:
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=no_operation "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
|
||||
else:
|
||||
# document set sync
|
||||
doc_sets = fetch_document_sets_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
return False
|
||||
|
||||
# User group sync
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
# document set sync
|
||||
doc_sets = fetch_document_sets_for_document(document_id, db_session)
|
||||
update_doc_sets: set[str] = set(doc_sets)
|
||||
|
||||
fields = VespaDocumentFields(
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=sync "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
|
||||
except Exception as ex:
|
||||
e: Exception | None = None
|
||||
while True:
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
completion_status = (
|
||||
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
)
|
||||
break
|
||||
|
||||
task_logger.exception(
|
||||
f"vespa_metadata_sync_task exceptioned: doc={document_id}"
|
||||
# User group sync
|
||||
doc_access = get_access_for_document(
|
||||
document_id=document_id, db_session=db_session
|
||||
)
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
|
||||
if (
|
||||
self.max_retries is not None
|
||||
and self.request.retries >= self.max_retries
|
||||
):
|
||||
completion_status = (
|
||||
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
|
||||
)
|
||||
fields = VespaDocumentFields(
|
||||
document_sets=update_doc_sets,
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
)
|
||||
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
|
||||
break # we won't hit this, but it looks weird not to have it
|
||||
finally:
|
||||
task_logger.info(
|
||||
f"vespa_metadata_sync_task completed: status={completion_status.value} doc={document_id}"
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
chunks_affected = retry_index.update_single(
|
||||
document_id,
|
||||
tenant_id=tenant_id,
|
||||
chunk_count=doc.chunk_count,
|
||||
fields=fields,
|
||||
)
|
||||
|
||||
# update db last. Worst case = we crash right before this and
|
||||
# the sync might repeat again later
|
||||
mark_document_as_synced(document_id, db_session)
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"doc={document_id} "
|
||||
f"action=sync "
|
||||
f"chunks={chunks_affected} "
|
||||
f"elapsed={elapsed:.2f}"
|
||||
)
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
|
||||
return False
|
||||
except Exception as ex:
|
||||
e: Exception | None = None
|
||||
if isinstance(ex, RetryError):
|
||||
task_logger.warning(
|
||||
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
|
||||
)
|
||||
|
||||
# only set the inner exception if it is of type Exception
|
||||
e_temp = ex.last_attempt.exception()
|
||||
if isinstance(e_temp, Exception):
|
||||
e = e_temp
|
||||
else:
|
||||
e = ex
|
||||
|
||||
if isinstance(e, httpx.HTTPStatusError):
|
||||
if e.response.status_code == HTTPStatus.BAD_REQUEST:
|
||||
task_logger.exception(
|
||||
f"Non-retryable HTTPStatusError: "
|
||||
f"doc={document_id} "
|
||||
f"status={e.response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
task_logger.exception(
|
||||
f"Unexpected exception during vespa metadata sync: doc={document_id}"
|
||||
)
|
||||
|
||||
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
|
||||
return False
|
||||
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
|
||||
countdown = 2 ** (self.request.retries + 4)
|
||||
self.retry(exc=e, countdown=countdown)
|
||||
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from onyx.db.background_error import create_background_error
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
def emit_background_error(
|
||||
@@ -11,10 +9,5 @@ def emit_background_error(
|
||||
"""Currently just saves a row in the background_errors table.
|
||||
|
||||
In the future, could create notifications based on the severity."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
create_background_error(db_session, message, cc_pair_id)
|
||||
except IntegrityError as e:
|
||||
# Log an error if the cc_pair_id was deleted or any other exception occurs
|
||||
error_message = f"Failed to create background error: {str(e)}. Original message: {message}"
|
||||
create_background_error(db_session, error_message, None)
|
||||
with get_session_with_tenant() as db_session:
|
||||
create_background_error(db_session, message, cc_pair_id)
|
||||
|
||||
@@ -17,9 +17,6 @@ from typing import Optional
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -57,15 +54,6 @@ def _initializer(
|
||||
kwargs = {}
|
||||
|
||||
logger.info("Initializing spawned worker child process.")
|
||||
# 1. Get tenant_id from args or fallback to default
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
for arg in reversed(args):
|
||||
if isinstance(arg, str) and arg.startswith(TENANT_ID_PREFIX):
|
||||
tenant_id = arg
|
||||
break
|
||||
|
||||
# 2. Set the tenant context before running anything
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Reset the engine in the child process
|
||||
SqlEngine.reset_engine()
|
||||
@@ -93,8 +81,6 @@ def _initializer(
|
||||
queue.put(error_msg) # Send the exception to the parent process
|
||||
|
||||
sys.exit(255) # use 255 to indicate a generic exception
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _run_in_process(
|
||||
|
||||
@@ -15,13 +15,11 @@ from onyx.background.indexing.memory_tracer import MemoryTracer
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
|
||||
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
|
||||
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.factory import instantiate_connector
|
||||
from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
@@ -30,7 +28,7 @@ from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import create_index_attempt_error
|
||||
from onyx.db.index_attempt import get_index_attempt
|
||||
@@ -88,11 +86,6 @@ def _get_connector_runner(
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# validate the connector settings
|
||||
if not INTEGRATION_TESTS_MODE:
|
||||
runnable_connector.validate_connector_settings()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
|
||||
@@ -251,7 +244,7 @@ def _run_indexing(
|
||||
"""
|
||||
start_time = time.monotonic() # jsut used for logging
|
||||
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt_start:
|
||||
raise ValueError(
|
||||
@@ -377,7 +370,7 @@ def _run_indexing(
|
||||
document_count = 0
|
||||
chunk_count = 0
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
|
||||
if not index_attempt:
|
||||
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
|
||||
@@ -437,7 +430,7 @@ def _run_indexing(
|
||||
raise ConnectorStopSignal("Connector stop signal detected")
|
||||
|
||||
# TODO: should we move this into the above callback instead?
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# will exception if the connector/index attempt is marked as paused/failed
|
||||
_check_connector_and_attempt_status(
|
||||
db_session_temp, ctx, index_attempt_id
|
||||
@@ -446,7 +439,7 @@ def _run_indexing(
|
||||
# save record of any failures at the connector level
|
||||
if failure is not None:
|
||||
total_failures += 1
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
ctx.cc_pair_id,
|
||||
@@ -510,7 +503,7 @@ def _run_indexing(
|
||||
if document.id not in failed_document_ids
|
||||
]
|
||||
for document_id in successful_document_ids:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
if document_id in doc_id_to_unresolved_errors:
|
||||
logger.info(
|
||||
f"Resolving IndexAttemptError for document '{document_id}'"
|
||||
@@ -523,7 +516,7 @@ def _run_indexing(
|
||||
# add brand new failures
|
||||
if index_pipeline_result.failures:
|
||||
total_failures += len(index_pipeline_result.failures)
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
for failure in index_pipeline_result.failures:
|
||||
create_index_attempt_error(
|
||||
index_attempt_id,
|
||||
@@ -540,7 +533,7 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
# This new value is updated every batch, so UI can refresh per batch update
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# NOTE: Postgres uses the start of the transactions when computing `NOW()`
|
||||
# so we need either to commit() or to use a new session
|
||||
update_docs_indexed(
|
||||
@@ -562,7 +555,7 @@ def _run_indexing(
|
||||
check_checkpoint_size(checkpoint)
|
||||
|
||||
# save latest checkpoint
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
save_checkpoint(
|
||||
db_session=db_session_temp,
|
||||
index_attempt_id=index_attempt_id,
|
||||
@@ -574,29 +567,9 @@ def _run_indexing(
|
||||
"Connector run exceptioned after elapsed time: "
|
||||
f"{time.monotonic() - start_time} seconds"
|
||||
)
|
||||
if isinstance(e, ConnectorValidationError):
|
||||
# On validation errors during indexing, we want to cancel the indexing attempt
|
||||
# and mark the CCPair as invalid. This prevents the connector from being
|
||||
# used in the future until the credentials are updated.
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
reason=str(e),
|
||||
)
|
||||
|
||||
if ctx.is_primary:
|
||||
update_connector_credential_pair(
|
||||
db_session=db_session_temp,
|
||||
connector_id=ctx.connector_id,
|
||||
credential_id=ctx.credential_id,
|
||||
status=ConnectorCredentialPairStatus.INVALID,
|
||||
)
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
|
||||
elif isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
if isinstance(e, ConnectorStopSignal):
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
@@ -614,7 +587,7 @@ def _run_indexing(
|
||||
memory_tracer.stop()
|
||||
raise e
|
||||
else:
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session_temp,
|
||||
@@ -636,7 +609,7 @@ def _run_indexing(
|
||||
memory_tracer.stop()
|
||||
|
||||
elapsed_time = time.monotonic() - start_time
|
||||
with get_session_with_current_tenant() as db_session_temp:
|
||||
with get_session_with_tenant(tenant_id) as db_session_temp:
|
||||
# resolve entity-based errors
|
||||
for error in entity_based_unresolved_errors:
|
||||
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
|
||||
@@ -696,7 +669,7 @@ def run_indexing_entrypoint(
|
||||
TaskAttemptSingleton.set_cc_and_index_id(
|
||||
index_attempt_id, connector_credential_pair_id
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# TODO: remove long running session entirely
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
@@ -717,7 +690,7 @@ def run_indexing_entrypoint(
|
||||
f"credentials='{credential_id}'"
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
|
||||
|
||||
logger.info(
|
||||
|
||||
@@ -27,10 +27,8 @@ from onyx.file_store.utils import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.force import ForceUseTool
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.gpu_utils import gpu_status_request
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -82,26 +80,6 @@ class Answer:
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
|
||||
rerank_settings = search_request.rerank_settings
|
||||
|
||||
using_cloud_reranking = (
|
||||
rerank_settings is not None
|
||||
and rerank_settings.rerank_provider_type is not None
|
||||
)
|
||||
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
|
||||
|
||||
# TODO: this is a hack to force the query to be used for the search tool
|
||||
# this should be removed once we fully unify graph inputs (i.e.
|
||||
# remove SearchQuery entirely)
|
||||
if (
|
||||
force_use_tool.force_use
|
||||
and search_tool
|
||||
and force_use_tool.args
|
||||
and force_use_tool.tool_name == search_tool.name
|
||||
and QUERY_FIELD in force_use_tool.args
|
||||
):
|
||||
search_request.query = force_use_tool.args[QUERY_FIELD]
|
||||
|
||||
self.graph_inputs = GraphInputs(
|
||||
search_request=search_request,
|
||||
prompt_builder=prompt_builder,
|
||||
@@ -116,6 +94,7 @@ class Answer:
|
||||
force_use_tool=force_use_tool,
|
||||
using_tool_calling_llm=using_tool_calling_llm,
|
||||
)
|
||||
assert db_session, "db_session must be provided for agentic persistence"
|
||||
self.graph_persistence = GraphPersistence(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -125,7 +104,6 @@ class Answer:
|
||||
use_agentic_search=use_agentic_search,
|
||||
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
|
||||
allow_refinement=True,
|
||||
allow_agent_reranking=allow_agent_reranking,
|
||||
)
|
||||
self.graph_config = GraphConfig(
|
||||
inputs=self.graph_inputs,
|
||||
|
||||
@@ -190,8 +190,7 @@ def create_chat_chain(
|
||||
and previous_message.message_type == MessageType.ASSISTANT
|
||||
and mainline_messages
|
||||
):
|
||||
if current_message.refined_answer_improvement:
|
||||
mainline_messages[-1] = current_message
|
||||
mainline_messages[-1] = current_message
|
||||
else:
|
||||
mainline_messages.append(current_message)
|
||||
|
||||
|
||||
@@ -142,15 +142,6 @@ class MessageResponseIDInfo(BaseModel):
|
||||
reserved_assistant_message_id: int
|
||||
|
||||
|
||||
class AgentMessageIDInfo(BaseModel):
|
||||
level: int
|
||||
message_id: int
|
||||
|
||||
|
||||
class AgenticMessageResponseIDInfo(BaseModel):
|
||||
agentic_message_ids: list[AgentMessageIDInfo]
|
||||
|
||||
|
||||
class StreamingError(BaseModel):
|
||||
error: str
|
||||
stack_trace: str | None = None
|
||||
|
||||
@@ -7,12 +7,10 @@ from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
|
||||
from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException
|
||||
from onyx.chat.answer import Answer
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import create_temporary_persona
|
||||
from onyx.chat.models import AgenticMessageResponseIDInfo
|
||||
from onyx.chat.models import AgentMessageIDInfo
|
||||
from onyx.chat.models import AgentSearchPacket
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import AnswerPostInfo
|
||||
@@ -145,10 +143,9 @@ from onyx.utils.long_term_log import LongTermLogger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.timing import log_function_time
|
||||
from onyx.utils.timing import log_generator_function_time
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
|
||||
def _translate_citations(
|
||||
@@ -310,7 +307,6 @@ ChatPacket = (
|
||||
| CustomToolResponse
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
| AgenticMessageResponseIDInfo
|
||||
| StreamStopInfo
|
||||
| AgentSearchPacket
|
||||
)
|
||||
@@ -346,7 +342,7 @@ def stream_chat_message_objects(
|
||||
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
tenant_id = get_current_tenant_id()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
use_existing_user_message = new_msg_req.use_existing_user_message
|
||||
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
|
||||
|
||||
@@ -635,7 +631,6 @@ def stream_chat_message_objects(
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
reserved_message_id=reserved_message_id,
|
||||
is_agentic=new_msg_req.use_agentic_search,
|
||||
)
|
||||
|
||||
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
||||
@@ -747,13 +742,14 @@ def stream_chat_message_objects(
|
||||
files=latest_query_files,
|
||||
single_message_history=single_message_history,
|
||||
),
|
||||
system_message=default_build_system_message(prompt_config, llm.config),
|
||||
system_message=default_build_system_message(prompt_config),
|
||||
message_history=message_history,
|
||||
llm_config=llm.config,
|
||||
raw_user_query=final_msg.message,
|
||||
raw_user_uploaded_files=latest_query_files or [],
|
||||
single_message_history=single_message_history,
|
||||
)
|
||||
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
|
||||
|
||||
# LLM prompt building, response capturing, etc.
|
||||
answer = Answer(
|
||||
@@ -869,6 +865,7 @@ def stream_chat_message_objects(
|
||||
for img in img_generation_response
|
||||
if img.image_data
|
||||
],
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
info.ai_message_files.extend(
|
||||
[
|
||||
@@ -1018,7 +1015,7 @@ def stream_chat_message_objects(
|
||||
if info.message_specific_citations
|
||||
else None
|
||||
),
|
||||
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
|
||||
error=None,
|
||||
tool_call=(
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
|
||||
@@ -1036,7 +1033,6 @@ def stream_chat_message_objects(
|
||||
next_level = 1
|
||||
prev_message = gen_ai_response_message
|
||||
agent_answers = answer.llm_answer_by_level()
|
||||
agentic_message_ids = []
|
||||
while next_level in agent_answers:
|
||||
next_answer = agent_answers[next_level]
|
||||
info = info_by_subq[
|
||||
@@ -1057,12 +1053,7 @@ def stream_chat_message_objects(
|
||||
citations=info.message_specific_citations.citation_map
|
||||
if info.message_specific_citations
|
||||
else None,
|
||||
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
|
||||
refined_answer_improvement=refined_answer_improvement,
|
||||
is_agentic=True,
|
||||
)
|
||||
agentic_message_ids.append(
|
||||
AgentMessageIDInfo(level=next_level, message_id=next_answer_message.id)
|
||||
)
|
||||
next_level += 1
|
||||
prev_message = next_answer_message
|
||||
@@ -1070,9 +1061,11 @@ def stream_chat_message_objects(
|
||||
logger.debug("Committing messages")
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield translate_db_message_to_chat_message_detail(gen_ai_response_message)
|
||||
yield msg_detail_response
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.exception(error_msg)
|
||||
|
||||
@@ -12,7 +12,6 @@ from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_toke
|
||||
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
|
||||
from onyx.llm.models import PreviousMessage
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.llm.utils import check_message_tokens
|
||||
@@ -20,7 +19,6 @@ from onyx.llm.utils import message_to_prompt_and_imgs
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
|
||||
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
|
||||
from onyx.prompts.prompt_utils import drop_messages_history_overflow
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
@@ -33,16 +31,8 @@ from onyx.tools.tool import Tool
|
||||
|
||||
def default_build_system_message(
|
||||
prompt_config: PromptConfig,
|
||||
llm_config: LLMConfig,
|
||||
) -> SystemMessage | None:
|
||||
system_prompt = prompt_config.system_prompt.strip()
|
||||
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
|
||||
# for o-series markdown generation
|
||||
if (
|
||||
llm_config.model_provider == OPENAI_PROVIDER_NAME
|
||||
and llm_config.model_name.startswith("o")
|
||||
):
|
||||
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
|
||||
tag_handled_prompt = handle_onyx_date_awareness(
|
||||
system_prompt,
|
||||
prompt_config,
|
||||
@@ -120,8 +110,21 @@ class AnswerPromptBuilder:
|
||||
),
|
||||
)
|
||||
|
||||
self.update_system_prompt(system_message)
|
||||
self.update_user_prompt(user_message)
|
||||
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
|
||||
(
|
||||
system_message,
|
||||
check_message_tokens(system_message, self.llm_tokenizer_encode_func),
|
||||
)
|
||||
if system_message
|
||||
else None
|
||||
)
|
||||
self.user_message_and_token_cnt = (
|
||||
user_message,
|
||||
check_message_tokens(
|
||||
user_message,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
|
||||
|
||||
@@ -31,9 +31,22 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
|
||||
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
|
||||
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
|
||||
|
||||
AGENT_ANSWER_GENERATION_BY_FAST_LLM = (
|
||||
os.environ.get("AGENT_ANSWER_GENERATION_BY_FAST_LLM", "").lower() == "true"
|
||||
)
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30 # in seconds
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 1 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 3 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8 # in seconds
|
||||
|
||||
|
||||
AGENT_RETRIEVAL_STATS = (
|
||||
not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
|
||||
@@ -165,172 +178,80 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
|
||||
) # 2000
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 10 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
|
||||
) # 25
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 3 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
|
||||
) # 3
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
|
||||
) # 30
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 4 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
|
||||
) # 12
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
|
||||
) # 25
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
|
||||
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
|
||||
) # 25
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 4 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
|
||||
) # 6
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 4 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
|
||||
) # 1
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 3 # in seconds
|
||||
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
|
||||
) # 4
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 4 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 4 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 4 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
|
||||
) # 8
|
||||
|
||||
GRAPH_VERSION_NAME: str = "a"
|
||||
|
||||
@@ -626,8 +626,6 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
|
||||
|
||||
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
|
||||
|
||||
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
|
||||
|
||||
@@ -98,18 +98,9 @@ CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
|
||||
|
||||
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
|
||||
|
||||
|
||||
# hard timeout applied by the watchdog to the indexing connector run
|
||||
# to handle hung connectors
|
||||
CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds)
|
||||
|
||||
# soft timeout for the lock taken by the indexing connector run
|
||||
# allows the lock to eventually expire if the managing code around it dies
|
||||
# needs to be long enough to cover the maximum time it takes to download an object
|
||||
# if we can get callbacks as object bytes download, we could lower this a lot.
|
||||
# CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 15 minutes
|
||||
# hard termination should always fire first if the connector is hung
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900
|
||||
|
||||
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
|
||||
|
||||
# how long a task should wait for associated fence to be ready
|
||||
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
|
||||
@@ -342,9 +333,6 @@ class OnyxRedisSignals:
|
||||
BLOCK_PRUNING = "signal:block_pruning"
|
||||
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
|
||||
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
|
||||
BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
|
||||
"signal:block_validate_connector_deletion_fences"
|
||||
)
|
||||
|
||||
|
||||
class OnyxRedisConstants:
|
||||
|
||||
@@ -7,18 +7,11 @@ from typing import Optional
|
||||
|
||||
import boto3 # type: ignore
|
||||
from botocore.client import Config # type: ignore
|
||||
from botocore.exceptions import ClientError
|
||||
from botocore.exceptions import NoCredentialsError
|
||||
from botocore.exceptions import PartialCredentialsError
|
||||
from mypy_boto3_s3 import S3Client # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import BlobType
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -247,73 +240,6 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.s3_client is None:
|
||||
raise ConnectorMissingCredentialError(
|
||||
"Blob storage credentials not loaded."
|
||||
)
|
||||
|
||||
if not self.bucket_name:
|
||||
raise ConnectorValidationError(
|
||||
"No bucket name was provided in connector settings."
|
||||
)
|
||||
|
||||
try:
|
||||
# We only fetch one object/page as a light-weight validation step.
|
||||
# This ensures we trigger typical S3 permission checks (ListObjectsV2, etc.).
|
||||
self.s3_client.list_objects_v2(
|
||||
Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
|
||||
)
|
||||
|
||||
except NoCredentialsError:
|
||||
raise ConnectorMissingCredentialError(
|
||||
"No valid blob storage credentials found or provided to boto3."
|
||||
)
|
||||
except PartialCredentialsError:
|
||||
raise ConnectorMissingCredentialError(
|
||||
"Partial or incomplete blob storage credentials provided to boto3."
|
||||
)
|
||||
except ClientError as e:
|
||||
error_code = e.response["Error"].get("Code", "")
|
||||
status_code = e.response["ResponseMetadata"].get("HTTPStatusCode")
|
||||
|
||||
# Most common S3 error cases
|
||||
if error_code in [
|
||||
"AccessDenied",
|
||||
"InvalidAccessKeyId",
|
||||
"SignatureDoesNotMatch",
|
||||
]:
|
||||
if status_code == 403 or error_code == "AccessDenied":
|
||||
raise InsufficientPermissionsError(
|
||||
f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
|
||||
"Please check your bucket policy and/or IAM policy."
|
||||
)
|
||||
if status_code == 401 or error_code == "SignatureDoesNotMatch":
|
||||
raise CredentialExpiredError(
|
||||
"Provided blob storage credentials appear invalid or expired."
|
||||
)
|
||||
|
||||
raise CredentialExpiredError(
|
||||
f"Credential issue encountered ({error_code})."
|
||||
)
|
||||
|
||||
if error_code == "NoSuchBucket" or status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Bucket '{self.bucket_name}' does not exist or cannot be found."
|
||||
)
|
||||
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Catch-all for anything not captured by the above
|
||||
# Since we are unsure of the error and it may not disable the connector,
|
||||
# raise an unexpected error (does not disable connector)
|
||||
raise UnexpectedError(
|
||||
f"Unexpected error during blob storage settings validation: {e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
credentials_dict = {
|
||||
|
||||
@@ -5,8 +5,6 @@ import requests
|
||||
|
||||
class BookStackClientRequestFailedError(ConnectionError):
|
||||
def __init__(self, status: int, error: str) -> None:
|
||||
self.status_code = status
|
||||
self.error = error
|
||||
super().__init__(
|
||||
"BookStack Client request failed with status {status}: {error}".format(
|
||||
status=status, error=error
|
||||
|
||||
@@ -7,11 +7,7 @@ from typing import Any
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.bookstack.client import BookStackApiClient
|
||||
from onyx.connectors.bookstack.client import BookStackClientRequestFailedError
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -218,39 +214,3 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
break
|
||||
else:
|
||||
time.sleep(0.2)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""
|
||||
Validate that the BookStack credentials and connector settings are correct.
|
||||
Specifically checks that we can make an authenticated request to BookStack.
|
||||
"""
|
||||
if not self.bookstack_client:
|
||||
raise ConnectorMissingCredentialError(
|
||||
"BookStack credentials have not been loaded."
|
||||
)
|
||||
|
||||
try:
|
||||
# Attempt to fetch a small batch of books (arbitrary endpoint) to verify credentials
|
||||
_ = self.bookstack_client.get(
|
||||
"/books", params={"count": "1", "offset": "0"}
|
||||
)
|
||||
|
||||
except BookStackClientRequestFailedError as e:
|
||||
# Check for HTTP status codes
|
||||
if e.status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Your BookStack credentials appear to be invalid or expired (HTTP 401)."
|
||||
) from e
|
||||
elif e.status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"The configured BookStack token does not have sufficient permissions (HTTP 403)."
|
||||
) from e
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected BookStack error (status={e.status_code}): {e}"
|
||||
) from e
|
||||
|
||||
except Exception as exc:
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected error while validating BookStack connector settings: {exc}"
|
||||
) from exc
|
||||
|
||||
@@ -4,8 +4,6 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from requests.exceptions import HTTPError
|
||||
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
@@ -18,10 +16,6 @@ from onyx.connectors.confluence.utils import build_confluence_document_id
|
||||
from onyx.connectors.confluence.utils import datetime_from_string
|
||||
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
@@ -403,33 +397,3 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
callback.progress("retrieve_all_slim_documents", 1)
|
||||
|
||||
yield doc_metadata_list
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self._confluence_client is None:
|
||||
raise ConnectorMissingCredentialError("Confluence credentials not loaded.")
|
||||
|
||||
try:
|
||||
spaces = self._confluence_client.get_all_spaces(limit=1)
|
||||
except HTTPError as e:
|
||||
status_code = e.response.status_code if e.response else None
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Invalid or expired Confluence credentials (HTTP 401)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Insufficient permissions to access Confluence resources (HTTP 403)."
|
||||
)
|
||||
raise UnexpectedError(
|
||||
f"Unexpected Confluence error (status={status_code}): {e}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise UnexpectedError(
|
||||
f"Unexpected error while validating Confluence settings: {e}"
|
||||
)
|
||||
|
||||
if not spaces or not spaces.get("results"):
|
||||
raise ConnectorValidationError(
|
||||
"No Confluence spaces found. Either your credentials lack permissions, or "
|
||||
"there truly are no spaces in this Confluence instance."
|
||||
)
|
||||
|
||||
@@ -8,12 +8,8 @@ from typing import TypeVar
|
||||
from urllib.parse import quote
|
||||
|
||||
from atlassian import Confluence # type:ignore
|
||||
from pydantic import BaseModel
|
||||
from requests import HTTPError
|
||||
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -33,16 +29,6 @@ class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class ConfluenceUser(BaseModel):
|
||||
user_id: str # accountId in Cloud, userKey in Server
|
||||
username: str | None # Confluence Cloud doesn't give usernames
|
||||
display_name: str
|
||||
# Confluence Data Center doesn't give email back by default,
|
||||
# have to fetch it with a different endpoint
|
||||
email: str | None
|
||||
type: str
|
||||
|
||||
|
||||
def _handle_http_error(e: HTTPError, attempt: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
@@ -163,7 +149,7 @@ class OnyxConfluence(Confluence):
|
||||
)
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
|
||||
self, url_suffix: str, limit: int | None = None
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
This will paginate through the top level query.
|
||||
@@ -238,41 +224,9 @@ class OnyxConfluence(Confluence):
|
||||
raise e
|
||||
|
||||
# yield the results individually
|
||||
results = cast(list[dict[str, Any]], next_response.get("results", []))
|
||||
yield from results
|
||||
yield from next_response.get("results", [])
|
||||
|
||||
old_url_suffix = url_suffix
|
||||
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
|
||||
|
||||
# make sure we don't update the start by more than the amount
|
||||
# of results we were able to retrieve. The Confluence API has a
|
||||
# weird behavior where if you pass in a limit that is too large for
|
||||
# the configured server, it will artificially limit the amount of
|
||||
# results returned BUT will not apply this to the start parameter.
|
||||
# This will cause us to miss results.
|
||||
if url_suffix and "start" in url_suffix:
|
||||
new_start = get_start_param_from_url(url_suffix)
|
||||
previous_start = get_start_param_from_url(old_url_suffix)
|
||||
if new_start - previous_start > len(results):
|
||||
logger.warning(
|
||||
f"Start was updated by more than the amount of results "
|
||||
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
|
||||
f"Previous Start: {previous_start}, Len Results: {len(results)}."
|
||||
)
|
||||
|
||||
# Update the url_suffix to use the adjusted start
|
||||
adjusted_start = previous_start + len(results)
|
||||
url_suffix = update_param_in_path(
|
||||
url_suffix, "start", str(adjusted_start)
|
||||
)
|
||||
|
||||
# some APIs don't properly paginate, so we need to manually update the `start` param
|
||||
if auto_paginate and len(results) > 0:
|
||||
previous_start = get_start_param_from_url(old_url_suffix)
|
||||
updated_start = previous_start + len(results)
|
||||
url_suffix = update_param_in_path(
|
||||
old_url_suffix, "start", str(updated_start)
|
||||
)
|
||||
url_suffix = next_response.get("_links", {}).get("next")
|
||||
|
||||
def paginated_cql_retrieval(
|
||||
self,
|
||||
@@ -321,97 +275,21 @@ class OnyxConfluence(Confluence):
|
||||
self,
|
||||
expand: str | None = None,
|
||||
limit: int | None = None,
|
||||
) -> Iterator[ConfluenceUser]:
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
The search/user endpoint can be used to fetch users.
|
||||
It's a seperate endpoint from the content/search endpoint used only for users.
|
||||
Otherwise it's very similar to the content/search endpoint.
|
||||
"""
|
||||
if self.cloud:
|
||||
cql = "type=user"
|
||||
url = "rest/api/search/user"
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
url += f"?cql={cql}{expand_string}"
|
||||
# endpoint doesn't properly paginate, so we need to manually update the `start` param
|
||||
# thus the auto_paginate flag
|
||||
for user_result in self._paginate_url(url, limit, auto_paginate=True):
|
||||
# Example response:
|
||||
# {
|
||||
# 'user': {
|
||||
# 'type': 'known',
|
||||
# 'accountId': '712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
|
||||
# 'accountType': 'atlassian',
|
||||
# 'email': 'chris@danswer.ai',
|
||||
# 'publicName': 'Chris Weaver',
|
||||
# 'profilePicture': {
|
||||
# 'path': '/wiki/aa-avatar/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
|
||||
# 'width': 48,
|
||||
# 'height': 48,
|
||||
# 'isDefault': False
|
||||
# },
|
||||
# 'displayName': 'Chris Weaver',
|
||||
# 'isExternalCollaborator': False,
|
||||
# '_expandable': {
|
||||
# 'operations': '',
|
||||
# 'personalSpace': ''
|
||||
# },
|
||||
# '_links': {
|
||||
# 'self': 'https://danswerai.atlassian.net/wiki/rest/api/user?accountId=712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d'
|
||||
# }
|
||||
# },
|
||||
# 'title': 'Chris Weaver',
|
||||
# 'excerpt': '',
|
||||
# 'url': '/people/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
|
||||
# 'breadcrumbs': [],
|
||||
# 'entityType': 'user',
|
||||
# 'iconCssClass': 'aui-icon content-type-profile',
|
||||
# 'lastModified': '2025-02-18T04:08:03.579Z',
|
||||
# 'score': 0.0
|
||||
# }
|
||||
user = user_result["user"]
|
||||
yield ConfluenceUser(
|
||||
user_id=user["accountId"],
|
||||
username=None,
|
||||
display_name=user["displayName"],
|
||||
email=user.get("email"),
|
||||
type=user["accountType"],
|
||||
)
|
||||
else:
|
||||
# https://developer.atlassian.com/server/confluence/rest/v900/api-group-user/#api-rest-api-user-list-get
|
||||
# ^ is only available on data center deployments
|
||||
# Example response:
|
||||
# [
|
||||
# {
|
||||
# 'type': 'known',
|
||||
# 'username': 'admin',
|
||||
# 'userKey': '40281082950c5fe901950c61c55d0000',
|
||||
# 'profilePicture': {
|
||||
# 'path': '/images/icons/profilepics/default.svg',
|
||||
# 'width': 48,
|
||||
# 'height': 48,
|
||||
# 'isDefault': True
|
||||
# },
|
||||
# 'displayName': 'Admin Test',
|
||||
# '_links': {
|
||||
# 'self': 'http://localhost:8090/rest/api/user?key=40281082950c5fe901950c61c55d0000'
|
||||
# },
|
||||
# '_expandable': {
|
||||
# 'status': ''
|
||||
# }
|
||||
# }
|
||||
# ]
|
||||
for user in self._paginate_url("rest/api/user/list", limit):
|
||||
yield ConfluenceUser(
|
||||
user_id=user["userKey"],
|
||||
username=user["username"],
|
||||
display_name=user["displayName"],
|
||||
email=None,
|
||||
type=user.get("type", "user"),
|
||||
)
|
||||
cql = "type=user"
|
||||
url = "rest/api/search/user" if self.cloud else "rest/api/search"
|
||||
expand_string = f"&expand={expand}" if expand else ""
|
||||
url += f"?cql={cql}{expand_string}"
|
||||
yield from self._paginate_url(url, limit)
|
||||
|
||||
def paginated_groups_by_user_retrieval(
|
||||
self,
|
||||
user_id: str, # accountId in Cloud, userKey in Server
|
||||
user: dict[str, Any],
|
||||
limit: int | None = None,
|
||||
) -> Iterator[dict[str, Any]]:
|
||||
"""
|
||||
@@ -419,7 +297,7 @@ class OnyxConfluence(Confluence):
|
||||
It's a confluence specific endpoint that can be used to fetch groups.
|
||||
"""
|
||||
user_field = "accountId" if self.cloud else "key"
|
||||
user_value = user_id
|
||||
user_value = user["accountId"] if self.cloud else user["userKey"]
|
||||
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
|
||||
user_query = f"{user_field}={quote(user_value)}"
|
||||
|
||||
@@ -545,15 +423,11 @@ def build_confluence_client(
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> OnyxConfluence:
|
||||
try:
|
||||
_validate_connector_configuration(
|
||||
credentials=credentials,
|
||||
is_cloud=is_cloud,
|
||||
wiki_base=wiki_base,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(str(e))
|
||||
|
||||
_validate_connector_configuration(
|
||||
credentials=credentials,
|
||||
is_cloud=is_cloud,
|
||||
wiki_base=wiki_base,
|
||||
)
|
||||
return OnyxConfluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
# Remove trailing slash from wiki_base if present
|
||||
|
||||
@@ -2,10 +2,7 @@ import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
|
||||
@@ -13,13 +10,13 @@ from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
OnyxConfluence,
|
||||
)
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -27,7 +24,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def get_user_email_from_username__server(
|
||||
confluence_client: "OnyxConfluence", user_name: str
|
||||
confluence_client: OnyxConfluence, user_name: str
|
||||
) -> str | None:
|
||||
global _USER_EMAIL_CACHE
|
||||
if _USER_EMAIL_CACHE.get(user_name) is None:
|
||||
@@ -50,7 +47,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
|
||||
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
"""Get Confluence Display Name based on the account-id or userkey value
|
||||
|
||||
Args:
|
||||
@@ -81,7 +78,7 @@ def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: "OnyxConfluence",
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
@@ -194,7 +191,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
@@ -282,32 +279,3 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
datetime_object = datetime_object.astimezone(timezone.utc)
|
||||
|
||||
return datetime_object
|
||||
|
||||
|
||||
def get_single_param_from_url(url: str, param: str) -> str | None:
|
||||
"""Get a parameter from a url"""
|
||||
parsed_url = urlparse(url)
|
||||
return parse_qs(parsed_url.query).get(param, [None])[0]
|
||||
|
||||
|
||||
def get_start_param_from_url(url: str) -> int:
|
||||
"""Get the start parameter from a url"""
|
||||
start_str = get_single_param_from_url(url, "start")
|
||||
if start_str is None:
|
||||
return 0
|
||||
return int(start_str)
|
||||
|
||||
|
||||
def update_param_in_path(path: str, param: str, value: str) -> str:
|
||||
"""Update a parameter in a path. Path should look something like:
|
||||
|
||||
/api/rest/users?start=0&limit=10
|
||||
"""
|
||||
parsed_url = urlparse(path)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
query_params[param] = [value]
|
||||
return (
|
||||
path.split("?")[0]
|
||||
+ "?"
|
||||
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
@@ -25,22 +24,16 @@ def datetime_to_utc(dt: datetime) -> datetime:
|
||||
|
||||
|
||||
def time_str_to_utc(datetime_str: str) -> datetime:
|
||||
# Remove all timezone abbreviations in parentheses
|
||||
datetime_str = re.sub(r"\([A-Z]+\)", "", datetime_str).strip()
|
||||
|
||||
# Remove any remaining parentheses and their contents
|
||||
datetime_str = re.sub(r"\(.*?\)", "", datetime_str).strip()
|
||||
|
||||
try:
|
||||
dt = parse(datetime_str)
|
||||
except ValueError:
|
||||
# Fix common format issues (e.g. "0000" => "+0000")
|
||||
# Handle malformed timezone by attempting to fix common format issues
|
||||
if "0000" in datetime_str:
|
||||
datetime_str = datetime_str.replace(" 0000", " +0000")
|
||||
dt = parse(datetime_str)
|
||||
# Convert "0000" to "+0000" for proper timezone parsing
|
||||
fixed_dt_str = datetime_str.replace(" 0000", " +0000")
|
||||
dt = parse(fixed_dt_str)
|
||||
else:
|
||||
raise
|
||||
|
||||
return datetime_to_utc(dt)
|
||||
|
||||
|
||||
|
||||
@@ -4,15 +4,11 @@ from typing import Any
|
||||
|
||||
from dropbox import Dropbox # type: ignore
|
||||
from dropbox.exceptions import ApiError # type:ignore
|
||||
from dropbox.exceptions import AuthError # type:ignore
|
||||
from dropbox.files import FileMetadata # type:ignore
|
||||
from dropbox.files import FolderMetadata # type:ignore
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialInvalidError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -145,29 +141,6 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
|
||||
return None
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.dropbox_client is None:
|
||||
raise ConnectorMissingCredentialError("Dropbox credentials not loaded.")
|
||||
|
||||
try:
|
||||
self.dropbox_client.files_list_folder(path="", limit=1)
|
||||
except AuthError as e:
|
||||
logger.exception("Failed to validate Dropbox credentials")
|
||||
raise CredentialInvalidError(f"Dropbox credential is invalid: {e.error}")
|
||||
except ApiError as e:
|
||||
if (
|
||||
e.error is not None
|
||||
and "insufficient_permissions" in str(e.error).lower()
|
||||
):
|
||||
raise InsufficientPermissionsError(
|
||||
"Your Dropbox token does not have sufficient permissions."
|
||||
)
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected Dropbox error during validation: {e.user_message_text or e}"
|
||||
)
|
||||
except Exception as e:
|
||||
raise Exception(f"Unexpected error during Dropbox settings validation: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
@@ -1,49 +0,0 @@
|
||||
class ValidationError(Exception):
|
||||
"""General exception for validation errors."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class ConnectorValidationError(ValidationError):
|
||||
"""General exception for connector validation errors."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
class UnexpectedError(ValidationError):
|
||||
"""Raised when an unexpected error occurs during connector validation.
|
||||
|
||||
Unexpected errors don't necessarily mean the credential is invalid,
|
||||
but rather that there was an error during the validation process
|
||||
or we encountered a currently unhandled error case.
|
||||
"""
|
||||
|
||||
def __init__(self, message: str = "Unexpected error during connector validation"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CredentialInvalidError(ConnectorValidationError):
|
||||
"""Raised when a connector's credential is invalid."""
|
||||
|
||||
def __init__(self, message: str = "Credential is invalid"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class CredentialExpiredError(ConnectorValidationError):
|
||||
"""Raised when a connector's credential is expired."""
|
||||
|
||||
def __init__(self, message: str = "Credential has expired"):
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class InsufficientPermissionsError(ConnectorValidationError):
|
||||
"""Raised when the credential does not have sufficient API permissions."""
|
||||
|
||||
def __init__(
|
||||
self, message: str = "Insufficient permissions for the requested operation"
|
||||
):
|
||||
super().__init__(message)
|
||||
@@ -3,7 +3,6 @@ from typing import Type
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
@@ -18,7 +17,6 @@ from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
from onyx.connectors.dropbox.connector import DropboxConnector
|
||||
from onyx.connectors.egnyte.connector import EgnyteConnector
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.file.connector import LocalFileConnector
|
||||
from onyx.connectors.fireflies.connector import FirefliesConnector
|
||||
from onyx.connectors.freshdesk.connector import FreshdeskConnector
|
||||
@@ -54,9 +52,7 @@ from onyx.connectors.wikipedia.connector import WikipediaConnector
|
||||
from onyx.connectors.xenforo.connector import XenforoConnector
|
||||
from onyx.connectors.zendesk.connector import ZendeskConnector
|
||||
from onyx.connectors.zulip.connector import ZulipConnector
|
||||
from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import backend_update_credential_json
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.models import Credential
|
||||
|
||||
|
||||
@@ -178,53 +174,3 @@ def instantiate_connector(
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
return connector
|
||||
|
||||
|
||||
def validate_ccpair_for_user(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None,
|
||||
enforce_creation: bool = True,
|
||||
) -> bool:
|
||||
if INTEGRATION_TESTS_MODE:
|
||||
return True
|
||||
|
||||
# Validate the connector settings
|
||||
connector = fetch_connector_by_id(connector_id, db_session)
|
||||
credential = fetch_credential_by_id(
|
||||
credential_id,
|
||||
db_session,
|
||||
)
|
||||
|
||||
if not connector:
|
||||
raise ValueError("Connector not found")
|
||||
|
||||
if (
|
||||
connector.source == DocumentSource.INGESTION_API
|
||||
or connector.source == DocumentSource.MOCK_CONNECTOR
|
||||
):
|
||||
return True
|
||||
|
||||
if not credential:
|
||||
raise ValueError("Credential not found")
|
||||
|
||||
try:
|
||||
runnable_connector = instantiate_connector(
|
||||
db_session=db_session,
|
||||
source=connector.source,
|
||||
input_type=connector.input_type,
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
credential=credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ConnectorValidationError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
if enforce_creation:
|
||||
raise ConnectorValidationError(str(e))
|
||||
else:
|
||||
return False
|
||||
|
||||
runnable_connector.validate_connector_settings()
|
||||
return True
|
||||
|
||||
@@ -181,7 +181,7 @@ class LocalFileConnector(LoadConnector):
|
||||
documents: list[Document] = []
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
|
||||
with get_session_with_tenant(self.tenant_id) as db_session:
|
||||
for file_path in self.file_locations:
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
files = _read_files_and_metadata(
|
||||
|
||||
@@ -187,12 +187,12 @@ class FirefliesConnector(PollConnector, LoadConnector):
|
||||
return self._process_transcripts()
|
||||
|
||||
def poll_source(
|
||||
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
|
||||
self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
|
||||
) -> GenerateDocumentsOutput:
|
||||
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S.000Z"
|
||||
)
|
||||
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
|
||||
start_datetime = datetime.fromtimestamp(
|
||||
start_unixtime, tz=timezone.utc
|
||||
).strftime("%Y-%m-%dT%H:%M:%S.000Z")
|
||||
end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%dT%H:%M:%S.000Z"
|
||||
)
|
||||
|
||||
|
||||
@@ -229,20 +229,16 @@ class GitbookConnector(LoadConnector, PollConnector):
|
||||
|
||||
try:
|
||||
content = self.client.get(f"/spaces/{self.space_id}/content")
|
||||
pages: list[dict[str, Any]] = content.get("pages", [])
|
||||
pages = content.get("pages", [])
|
||||
|
||||
current_batch: list[Document] = []
|
||||
for page in pages:
|
||||
updated_at = datetime.fromisoformat(page["updatedAt"])
|
||||
|
||||
while pages:
|
||||
page = pages.pop(0)
|
||||
|
||||
updated_at_raw = page.get("updatedAt")
|
||||
if updated_at_raw is None:
|
||||
# if updatedAt is not present, that means the page has never been edited
|
||||
continue
|
||||
|
||||
updated_at = datetime.fromisoformat(updated_at_raw)
|
||||
if start and updated_at < start:
|
||||
continue
|
||||
if current_batch:
|
||||
yield current_batch
|
||||
return
|
||||
if end and updated_at > end:
|
||||
continue
|
||||
|
||||
@@ -254,8 +250,6 @@ class GitbookConnector(LoadConnector, PollConnector):
|
||||
yield current_batch
|
||||
current_batch = []
|
||||
|
||||
pages.extend(page.get("pages", []))
|
||||
|
||||
if current_batch:
|
||||
yield current_batch
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ from typing import cast
|
||||
from github import Github
|
||||
from github import RateLimitExceededException
|
||||
from github import Repository
|
||||
from github.GithubException import GithubException
|
||||
from github.Issue import Issue
|
||||
from github.PaginatedList import PaginatedList
|
||||
from github.PullRequest import PullRequest
|
||||
@@ -17,10 +16,6 @@ from github.PullRequest import PullRequest
|
||||
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -31,6 +26,7 @@ from onyx.connectors.models import Section
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -124,7 +120,7 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
repo_name: str | None = None,
|
||||
repo_name: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
state_filter: str = "all",
|
||||
include_prs: bool = True,
|
||||
@@ -162,81 +158,53 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_github_repo(github_client, attempt_num + 1)
|
||||
|
||||
def _get_all_repos(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Repository.Repository]:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
|
||||
)
|
||||
|
||||
try:
|
||||
# Try to get organization first
|
||||
try:
|
||||
org = github_client.get_organization(self.repo_owner)
|
||||
return list(org.get_repos())
|
||||
except GithubException:
|
||||
# If not an org, try as a user
|
||||
user = github_client.get_user(self.repo_owner)
|
||||
return list(user.get_repos())
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_all_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _fetch_from_github(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
|
||||
repos = (
|
||||
[self._get_github_repo(self.github_client)]
|
||||
if self.repo_name
|
||||
else self._get_all_repos(self.github_client)
|
||||
)
|
||||
repo = self._get_github_repo(self.github_client)
|
||||
|
||||
for repo in repos:
|
||||
if self.include_prs:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
if self.include_prs:
|
||||
pull_requests = repo.get_pulls(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
for pr_batch in _batch_github_objects(
|
||||
pull_requests, self.github_client, self.batch_size
|
||||
):
|
||||
doc_batch: list[Document] = []
|
||||
for pr in pr_batch:
|
||||
if start is not None and pr.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and pr.updated_at > end:
|
||||
continue
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
yield doc_batch
|
||||
for pr_batch in _batch_github_objects(
|
||||
pull_requests, self.github_client, self.batch_size
|
||||
):
|
||||
doc_batch: list[Document] = []
|
||||
for pr in pr_batch:
|
||||
if start is not None and pr.updated_at < start:
|
||||
yield doc_batch
|
||||
return
|
||||
if end is not None and pr.updated_at > end:
|
||||
continue
|
||||
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
|
||||
yield doc_batch
|
||||
|
||||
if self.include_issues:
|
||||
logger.info(f"Fetching issues for repo: {repo.name}")
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
if self.include_issues:
|
||||
issues = repo.get_issues(
|
||||
state=self.state_filter, sort="updated", direction="desc"
|
||||
)
|
||||
|
||||
for issue_batch in _batch_github_objects(
|
||||
issues, self.github_client, self.batch_size
|
||||
):
|
||||
doc_batch = []
|
||||
for issue in issue_batch:
|
||||
issue = cast(Issue, issue)
|
||||
if start is not None and issue.updated_at < start:
|
||||
yield doc_batch
|
||||
break
|
||||
if end is not None and issue.updated_at > end:
|
||||
continue
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield doc_batch
|
||||
for issue_batch in _batch_github_objects(
|
||||
issues, self.github_client, self.batch_size
|
||||
):
|
||||
doc_batch = []
|
||||
for issue in issue_batch:
|
||||
issue = cast(Issue, issue)
|
||||
if start is not None and issue.updated_at < start:
|
||||
yield doc_batch
|
||||
return
|
||||
if end is not None and issue.updated_at > end:
|
||||
continue
|
||||
if issue.pull_request is not None:
|
||||
# PRs are handled separately
|
||||
continue
|
||||
doc_batch.append(_convert_issue_to_document(issue))
|
||||
yield doc_batch
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
return self._fetch_from_github()
|
||||
@@ -258,63 +226,6 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
|
||||
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
|
||||
|
||||
if not self.repo_owner:
|
||||
raise ConnectorValidationError(
|
||||
"Invalid connector settings: 'repo_owner' must be provided."
|
||||
)
|
||||
|
||||
try:
|
||||
if self.repo_name:
|
||||
test_repo = self.github_client.get_repo(
|
||||
f"{self.repo_owner}/{self.repo_name}"
|
||||
)
|
||||
test_repo.get_contents("")
|
||||
else:
|
||||
# Try to get organization first
|
||||
try:
|
||||
org = self.github_client.get_organization(self.repo_owner)
|
||||
org.get_repos().totalCount # Just check if we can access repos
|
||||
except GithubException:
|
||||
# If not an org, try as a user
|
||||
user = self.github_client.get_user(self.repo_owner)
|
||||
user.get_repos().totalCount # Just check if we can access repos
|
||||
|
||||
except RateLimitExceededException:
|
||||
raise UnexpectedError(
|
||||
"Validation failed due to GitHub rate-limits being exceeded. Please try again later."
|
||||
)
|
||||
|
||||
except GithubException as e:
|
||||
if e.status == 401:
|
||||
raise CredentialExpiredError(
|
||||
"GitHub credential appears to be invalid or expired (HTTP 401)."
|
||||
)
|
||||
elif e.status == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
|
||||
)
|
||||
elif e.status == 404:
|
||||
if self.repo_name:
|
||||
raise ConnectorValidationError(
|
||||
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"GitHub user or organization not found: {self.repo_owner}"
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected GitHub error (status={e.status}): {e.data}"
|
||||
)
|
||||
except Exception as exc:
|
||||
raise Exception(
|
||||
f"Unexpected error during GitHub settings validation: {exc}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
@@ -297,7 +297,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
userId=user_email,
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
full_threads = execute_paginated_retrieval(
|
||||
retrieval_function=gmail_service.users().threads().get,
|
||||
@@ -305,7 +304,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
userId=user_email,
|
||||
fields=THREAD_FIELDS,
|
||||
id=thread["id"],
|
||||
continue_on_404_or_403=True,
|
||||
)
|
||||
# full_threads is an iterator containing a single thread
|
||||
# so we need to convert it to a list and grab the first element
|
||||
@@ -337,7 +335,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
userId=user_email,
|
||||
fields=THREAD_LIST_FIELDS,
|
||||
q=query,
|
||||
continue_on_404_or_403=True,
|
||||
):
|
||||
doc_batch.append(
|
||||
SlimDocument(
|
||||
|
||||
@@ -13,9 +13,6 @@ from googleapiclient.errors import HttpError # type: ignore
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.google_drive.doc_conversion import build_slim_document
|
||||
from onyx.connectors.google_drive.doc_conversion import (
|
||||
convert_drive_item_to_document,
|
||||
@@ -45,7 +42,6 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
@@ -141,7 +137,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"Please visit the docs for help with the new setup: "
|
||||
f"{SCOPE_DOC_URL}"
|
||||
)
|
||||
raise ConnectorValidationError(
|
||||
raise ValueError(
|
||||
"Google Drive connector received old input parameters. "
|
||||
"Please visit the docs for help with the new setup: "
|
||||
f"{SCOPE_DOC_URL}"
|
||||
@@ -155,7 +151,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
and not my_drive_emails
|
||||
and not shared_drive_urls
|
||||
):
|
||||
raise ConnectorValidationError(
|
||||
raise ValueError(
|
||||
"Nothing to index. Please specify at least one of the following: "
|
||||
"include_shared_drives, include_my_drives, include_files_shared_with_me, "
|
||||
"shared_folder_urls, or my_drive_emails"
|
||||
@@ -224,14 +220,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
return self._creds
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
|
||||
try:
|
||||
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
|
||||
except KeyError:
|
||||
raise ValueError(
|
||||
"Primary admin email missing, "
|
||||
"should not call this property "
|
||||
"before calling load_credentials"
|
||||
)
|
||||
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
|
||||
|
||||
self._creds, new_creds_dict = get_google_creds(
|
||||
credentials=credentials,
|
||||
@@ -613,50 +602,3 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
|
||||
raise e
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self._creds is None:
|
||||
raise ConnectorMissingCredentialError(
|
||||
"Google Drive credentials not loaded."
|
||||
)
|
||||
|
||||
if self._primary_admin_email is None:
|
||||
raise ConnectorValidationError(
|
||||
"Primary admin email not found in credentials. "
|
||||
"Ensure DB_CREDENTIALS_PRIMARY_ADMIN_KEY is set."
|
||||
)
|
||||
|
||||
try:
|
||||
drive_service = get_drive_service(self._creds, self._primary_admin_email)
|
||||
drive_service.files().list(pageSize=1, fields="files(id)").execute()
|
||||
|
||||
if isinstance(self._creds, ServiceAccountCredentials):
|
||||
retry_builder()(get_root_folder_id)(drive_service)
|
||||
|
||||
except HttpError as e:
|
||||
status_code = e.resp.status if e.resp else None
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Invalid or expired Google Drive credentials (401)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Google Drive app lacks required permissions (403). "
|
||||
"Please ensure the necessary scopes are granted and Drive "
|
||||
"apps are enabled."
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected Google Drive error (status={status_code}): {e}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# Check for scope-related hints from the error message
|
||||
if MISSING_SCOPES_ERROR_STR in str(e):
|
||||
raise InsufficientPermissionsError(
|
||||
"Google Drive credentials are missing required scopes. "
|
||||
f"{ONYX_SCOPE_INSTRUCTIONS}"
|
||||
)
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected error during Google Drive validation: {e}"
|
||||
)
|
||||
|
||||
@@ -87,18 +87,16 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
contact = api_client.crm.contacts.basic_api.get_by_id(
|
||||
contact_id=contact.id
|
||||
)
|
||||
email = contact.properties.get("email")
|
||||
if email is not None:
|
||||
associated_emails.append(email)
|
||||
associated_emails.append(contact.properties["email"])
|
||||
|
||||
if notes:
|
||||
for note in notes.results:
|
||||
note = api_client.crm.objects.notes.basic_api.get_by_id(
|
||||
note_id=note.id, properties=["content", "hs_body_preview"]
|
||||
)
|
||||
preview = note.properties.get("hs_body_preview")
|
||||
if preview is not None:
|
||||
associated_notes.append(preview)
|
||||
if note.properties["hs_body_preview"] is None:
|
||||
continue
|
||||
associated_notes.append(note.properties["hs_body_preview"])
|
||||
|
||||
associated_emails_str = " ,".join(associated_emails)
|
||||
associated_notes_str = " ".join(associated_notes)
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
|
||||
|
||||
SecondsSinceUnixEpoch = float
|
||||
|
||||
GenerateDocumentsOutput = Iterator[list[Document]]
|
||||
@@ -44,14 +45,6 @@ class BaseConnector(abc.ABC):
|
||||
raise RuntimeError(custom_parser_req_msg)
|
||||
return metadata_lines
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""
|
||||
Override this if your connector needs to validate credentials or settings.
|
||||
Raise an exception if invalid, otherwise do nothing.
|
||||
|
||||
Default is a no-op (always successful).
|
||||
"""
|
||||
|
||||
|
||||
# Large set update or reindex, generally pulling a complete state or from a savestate file
|
||||
class LoadConnector(BaseConnector):
|
||||
|
||||
@@ -7,7 +7,6 @@ from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
import requests
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
@@ -16,15 +15,10 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rl_requests,
|
||||
)
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.batching import batch_generator
|
||||
@@ -622,64 +616,6 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
else:
|
||||
break
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if not self.headers.get("Authorization"):
|
||||
raise ConnectorMissingCredentialError("Notion credentials not loaded.")
|
||||
|
||||
try:
|
||||
# We'll do a minimal search call (page_size=1) to confirm accessibility
|
||||
if self.root_page_id:
|
||||
# If root_page_id is set, fetch the specific page
|
||||
res = rl_requests.get(
|
||||
f"https://api.notion.com/v1/pages/{self.root_page_id}",
|
||||
headers=self.headers,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
else:
|
||||
# If root_page_id is not set, perform a minimal search
|
||||
test_query = {
|
||||
"filter": {"property": "object", "value": "page"},
|
||||
"page_size": 1,
|
||||
}
|
||||
res = rl_requests.post(
|
||||
"https://api.notion.com/v1/search",
|
||||
headers=self.headers,
|
||||
json=test_query,
|
||||
timeout=_NOTION_CALL_TIMEOUT,
|
||||
)
|
||||
res.raise_for_status()
|
||||
|
||||
except requests.exceptions.HTTPError as http_err:
|
||||
status_code = http_err.response.status_code if http_err.response else None
|
||||
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Notion credential appears to be invalid or expired (HTTP 401)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your Notion token does not have sufficient permissions (HTTP 403)."
|
||||
)
|
||||
elif status_code == 404:
|
||||
# Typically means resource not found or not shared. Could be root_page_id is invalid.
|
||||
raise ConnectorValidationError(
|
||||
"Notion resource not found or not shared with the integration (HTTP 404)."
|
||||
)
|
||||
elif status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Validation failed due to Notion rate-limits being exceeded (HTTP 429). "
|
||||
"Please try again later."
|
||||
)
|
||||
else:
|
||||
raise UnexpectedError(
|
||||
f"Unexpected Notion HTTP error (status={status_code}): {http_err}"
|
||||
) from http_err
|
||||
|
||||
except Exception as exc:
|
||||
raise UnexpectedError(
|
||||
f"Unexpected error during Notion settings validation: {exc}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
@@ -12,9 +12,6 @@ from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
|
||||
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
@@ -29,6 +26,7 @@ from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
|
||||
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
|
||||
from onyx.connectors.onyx_jira.utils import build_jira_client
|
||||
from onyx.connectors.onyx_jira.utils import build_jira_url
|
||||
from onyx.connectors.onyx_jira.utils import extract_jira_project
|
||||
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
|
||||
from onyx.connectors.onyx_jira.utils import get_comment_strs
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
@@ -159,8 +157,7 @@ def fetch_jira_issues_batch(
|
||||
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
def __init__(
|
||||
self,
|
||||
jira_base_url: str,
|
||||
project_key: str | None = None,
|
||||
jira_project_url: str,
|
||||
comment_email_blacklist: list[str] | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
# if a ticket has one of the labels specified in this list, we will just
|
||||
@@ -169,12 +166,11 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
|
||||
self.jira_project = project_key
|
||||
self._comment_email_blacklist = comment_email_blacklist or []
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
|
||||
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
|
||||
self._jira_client: JIRA | None = None
|
||||
self._comment_email_blacklist = comment_email_blacklist or []
|
||||
|
||||
self.labels_to_skip = set(labels_to_skip)
|
||||
|
||||
@property
|
||||
def comment_email_blacklist(self) -> tuple:
|
||||
@@ -189,9 +185,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
@property
|
||||
def quoted_jira_project(self) -> str:
|
||||
# Quote the project name to handle reserved words
|
||||
if not self.jira_project:
|
||||
return ""
|
||||
return f'"{self.jira_project}"'
|
||||
return f'"{self._jira_project}"'
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._jira_client = build_jira_client(
|
||||
@@ -200,14 +194,8 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_jql_query(self) -> str:
|
||||
"""Get the JQL query based on whether a specific project is set"""
|
||||
if self.jira_project:
|
||||
return f"project = {self.quoted_jira_project}"
|
||||
return "" # Empty string means all accessible projects
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
jql = self._get_jql_query()
|
||||
jql = f"project = {self.quoted_jira_project}"
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
@@ -234,10 +222,11 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
|
||||
base_jql = self._get_jql_query()
|
||||
jql = (
|
||||
f"{base_jql} AND " if base_jql else ""
|
||||
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
|
||||
f"project = {self.quoted_jira_project} AND "
|
||||
f"updated >= '{start_date_str}' AND "
|
||||
f"updated <= '{end_date_str}'"
|
||||
)
|
||||
|
||||
document_batch = []
|
||||
for doc in fetch_jira_issues_batch(
|
||||
@@ -260,7 +249,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
jql = self._get_jql_query()
|
||||
jql = f"project = {self.quoted_jira_project}"
|
||||
|
||||
slim_doc_batch = []
|
||||
for issue in _paginate_jql_search(
|
||||
@@ -283,67 +272,13 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
yield slim_doc_batch
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self._jira_client is None:
|
||||
raise ConnectorMissingCredentialError("Jira")
|
||||
|
||||
# If a specific project is set, validate it exists
|
||||
if self.jira_project:
|
||||
try:
|
||||
self.jira_client.project(self.jira_project)
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", None)
|
||||
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Jira credential appears to be expired or invalid (HTTP 401)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
|
||||
)
|
||||
elif status_code == 404:
|
||||
raise ConnectorValidationError(
|
||||
f"Jira project not found with key: {self.jira_project}"
|
||||
)
|
||||
elif status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
|
||||
else:
|
||||
# If no project specified, validate we can access the Jira API
|
||||
try:
|
||||
# Try to list projects to validate access
|
||||
self.jira_client.projects()
|
||||
except Exception as e:
|
||||
status_code = getattr(e, "status_code", None)
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Jira credential appears to be expired or invalid (HTTP 401)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your Jira token does not have sufficient permissions to list projects (HTTP 403)."
|
||||
)
|
||||
elif status_code == 429:
|
||||
raise ConnectorValidationError(
|
||||
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
connector = JiraConnector(
|
||||
jira_base_url=os.environ["JIRA_BASE_URL"],
|
||||
project_key=os.environ.get("JIRA_PROJECT_KEY"),
|
||||
comment_email_blacklist=[],
|
||||
os.environ["JIRA_PROJECT_URL"], comment_email_blacklist=[]
|
||||
)
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
"jira_user_email": os.environ["JIRA_USER_EMAIL"],
|
||||
|
||||
@@ -18,10 +18,6 @@ from slack_sdk.errors import SlackApiError
|
||||
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
@@ -86,14 +82,14 @@ def get_channels(
|
||||
get_public: bool = True,
|
||||
get_private: bool = True,
|
||||
) -> list[ChannelType]:
|
||||
"""Get all channels in the workspace."""
|
||||
"""Get all channels in the workspace"""
|
||||
channels: list[dict[str, Any]] = []
|
||||
channel_types = []
|
||||
if get_public:
|
||||
channel_types.append("public_channel")
|
||||
if get_private:
|
||||
channel_types.append("private_channel")
|
||||
# Try fetching both public and private channels first:
|
||||
# try getting private channels as well at first
|
||||
try:
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
@@ -101,19 +97,19 @@ def get_channels(
|
||||
channel_types=channel_types,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
logger.info(
|
||||
f"Unable to fetch private channels due to: {e}. Trying again without private channels."
|
||||
)
|
||||
logger.info(f"Unable to fetch private channels due to - {e}")
|
||||
logger.info("trying again without private channels")
|
||||
if get_public:
|
||||
channel_types = ["public_channel"]
|
||||
else:
|
||||
logger.warning("No channels to fetch.")
|
||||
logger.warning("No channels to fetch")
|
||||
return []
|
||||
channels = _collect_paginated_channels(
|
||||
client=client,
|
||||
exclude_archived=exclude_archived,
|
||||
channel_types=channel_types,
|
||||
)
|
||||
|
||||
return channels
|
||||
|
||||
|
||||
@@ -670,86 +666,6 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
)
|
||||
return checkpoint
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
"""
|
||||
1. Verify the bot token is valid for the workspace (via auth_test).
|
||||
2. Ensure the bot has enough scope to list channels.
|
||||
3. Check that every channel specified in self.channels exists.
|
||||
"""
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
|
||||
|
||||
try:
|
||||
# 1) Validate connection to workspace
|
||||
auth_response = self.client.auth_test()
|
||||
if not auth_response.get("ok", False):
|
||||
error_msg = auth_response.get(
|
||||
"error", "Unknown error from Slack auth_test"
|
||||
)
|
||||
raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")
|
||||
|
||||
# 2) Minimal test to confirm listing channels works
|
||||
test_resp = self.client.conversations_list(
|
||||
limit=1, types=["public_channel"]
|
||||
)
|
||||
if not test_resp.get("ok", False):
|
||||
error_msg = test_resp.get("error", "Unknown error from Slack")
|
||||
if error_msg == "invalid_auth":
|
||||
raise ConnectorValidationError(
|
||||
f"Invalid Slack bot token ({error_msg})."
|
||||
)
|
||||
elif error_msg == "not_authed":
|
||||
raise CredentialExpiredError(
|
||||
f"Invalid or expired Slack bot token ({error_msg})."
|
||||
)
|
||||
raise UnexpectedError(f"Slack API returned a failure: {error_msg}")
|
||||
|
||||
# 3) If channels are specified, verify each is accessible
|
||||
if self.channels:
|
||||
accessible_channels = get_channels(
|
||||
client=self.client,
|
||||
exclude_archived=True,
|
||||
get_public=True,
|
||||
get_private=True,
|
||||
)
|
||||
# For quick lookups by name or ID, build a map:
|
||||
accessible_channel_names = {ch["name"] for ch in accessible_channels}
|
||||
accessible_channel_ids = {ch["id"] for ch in accessible_channels}
|
||||
|
||||
for user_channel in self.channels:
|
||||
if (
|
||||
user_channel not in accessible_channel_names
|
||||
and user_channel not in accessible_channel_ids
|
||||
):
|
||||
raise ConnectorValidationError(
|
||||
f"Channel '{user_channel}' not found or inaccessible in this workspace."
|
||||
)
|
||||
|
||||
except SlackApiError as e:
|
||||
slack_error = e.response.get("error", "")
|
||||
if slack_error == "missing_scope":
|
||||
raise InsufficientPermissionsError(
|
||||
"Slack bot token lacks the necessary scope to list/access channels. "
|
||||
"Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
|
||||
)
|
||||
elif slack_error == "invalid_auth":
|
||||
raise CredentialExpiredError(
|
||||
f"Invalid Slack bot token ({slack_error})."
|
||||
)
|
||||
elif slack_error == "not_authed":
|
||||
raise CredentialExpiredError(
|
||||
f"Invalid or expired Slack bot token ({slack_error})."
|
||||
)
|
||||
raise UnexpectedError(
|
||||
f"Unexpected Slack error '{slack_error}' during settings validation."
|
||||
)
|
||||
except ConnectorValidationError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise UnexpectedError(
|
||||
f"Unexpected error during Slack settings validation: {e}"
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import Any
|
||||
|
||||
import msal # type: ignore
|
||||
from office365.graph_client import GraphClient # type: ignore
|
||||
from office365.runtime.client_request_exception import ClientRequestException # type: ignore
|
||||
from office365.teams.channels.channel import Channel # type: ignore
|
||||
from office365.teams.chats.messages.message import ChatMessage # type: ignore
|
||||
from office365.teams.team import Team # type: ignore
|
||||
@@ -13,10 +12,6 @@ from office365.teams.team import Team # type: ignore
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -284,50 +279,6 @@ class TeamsConnector(LoadConnector, PollConnector):
|
||||
end_datetime = datetime.fromtimestamp(end, timezone.utc)
|
||||
return self._fetch_from_teams(start=start_datetime, end=end_datetime)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
if self.graph_client is None:
|
||||
raise ConnectorMissingCredentialError("Teams credentials not loaded.")
|
||||
|
||||
try:
|
||||
# Minimal call to confirm we can retrieve Teams
|
||||
found_teams = self._get_all_teams()
|
||||
|
||||
except ClientRequestException as e:
|
||||
status_code = e.response.status_code
|
||||
if status_code == 401:
|
||||
raise CredentialExpiredError(
|
||||
"Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
|
||||
)
|
||||
elif status_code == 403:
|
||||
raise InsufficientPermissionsError(
|
||||
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
|
||||
)
|
||||
raise UnexpectedError(f"Unexpected error retrieving teams: {e}")
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
if (
|
||||
"unauthorized" in error_str
|
||||
or "401" in error_str
|
||||
or "invalid_grant" in error_str
|
||||
):
|
||||
raise CredentialExpiredError(
|
||||
"Invalid or expired Microsoft Teams credentials."
|
||||
)
|
||||
elif "forbidden" in error_str or "403" in error_str:
|
||||
raise InsufficientPermissionsError(
|
||||
"App lacks required permissions to read from Microsoft Teams."
|
||||
)
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected error during Teams validation: {e}"
|
||||
)
|
||||
|
||||
if not found_teams:
|
||||
raise ConnectorValidationError(
|
||||
"No Teams found for the given credentials. "
|
||||
"Either there are no Teams in this tenant, or your app does not have permission to view them."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = TeamsConnector(teams=os.environ["TEAMS"].split(","))
|
||||
|
||||
@@ -25,10 +25,6 @@ from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_TOKEN_URL
|
||||
from onyx.configs.app_configs import WEB_CONNECTOR_VALIDATE_URLS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
@@ -41,8 +37,6 @@ from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
|
||||
|
||||
|
||||
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
|
||||
# Given a base site, index everything under that path
|
||||
@@ -176,35 +170,26 @@ def start_playwright() -> Tuple[Playwright, BrowserContext]:
|
||||
|
||||
|
||||
def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
|
||||
try:
|
||||
response = requests.get(sitemap_url)
|
||||
response.raise_for_status()
|
||||
response = requests.get(sitemap_url)
|
||||
response.raise_for_status()
|
||||
|
||||
soup = BeautifulSoup(response.content, "html.parser")
|
||||
urls = [
|
||||
_ensure_absolute_url(sitemap_url, loc_tag.text)
|
||||
for loc_tag in soup.find_all("loc")
|
||||
]
|
||||
soup = BeautifulSoup(response.content, "html.parser")
|
||||
urls = [
|
||||
_ensure_absolute_url(sitemap_url, loc_tag.text)
|
||||
for loc_tag in soup.find_all("loc")
|
||||
]
|
||||
|
||||
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
|
||||
# the given url doesn't look like a sitemap, let's try to find one
|
||||
urls = list_pages_for_site(sitemap_url)
|
||||
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
|
||||
# the given url doesn't look like a sitemap, let's try to find one
|
||||
urls = list_pages_for_site(sitemap_url)
|
||||
|
||||
if len(urls) == 0:
|
||||
raise ValueError(
|
||||
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
|
||||
)
|
||||
|
||||
return urls
|
||||
except requests.RequestException as e:
|
||||
raise RuntimeError(f"Failed to fetch sitemap from {sitemap_url}: {e}")
|
||||
except ValueError as e:
|
||||
raise RuntimeError(f"Error processing sitemap {sitemap_url}: {e}")
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Unexpected error while processing sitemap {sitemap_url}: {e}"
|
||||
if len(urls) == 0:
|
||||
raise ValueError(
|
||||
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
|
||||
)
|
||||
|
||||
return urls
|
||||
|
||||
|
||||
def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:
|
||||
if not urlparse(maybe_relative_url).netloc:
|
||||
@@ -240,14 +225,10 @@ class WebConnector(LoadConnector):
|
||||
web_connector_type: str = WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value,
|
||||
mintlify_cleanup: bool = True, # Mostly ok to apply to other websites as well
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
scroll_before_scraping: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.mintlify_cleanup = mintlify_cleanup
|
||||
self.batch_size = batch_size
|
||||
self.recursive = False
|
||||
self.scroll_before_scraping = scroll_before_scraping
|
||||
self.web_connector_type = web_connector_type
|
||||
|
||||
if web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value:
|
||||
self.recursive = True
|
||||
@@ -363,18 +344,6 @@ class WebConnector(LoadConnector):
|
||||
continue
|
||||
visited_links.add(current_url)
|
||||
|
||||
if self.scroll_before_scraping:
|
||||
scroll_attempts = 0
|
||||
previous_height = page.evaluate("document.body.scrollHeight")
|
||||
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
|
||||
page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
|
||||
page.wait_for_load_state("networkidle", timeout=30000)
|
||||
new_height = page.evaluate("document.body.scrollHeight")
|
||||
if new_height == previous_height:
|
||||
break # Stop scrolling when no more content is loaded
|
||||
previous_height = new_height
|
||||
scroll_attempts += 1
|
||||
|
||||
content = page.content()
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
|
||||
@@ -433,56 +402,6 @@ class WebConnector(LoadConnector):
|
||||
raise RuntimeError(last_error)
|
||||
raise RuntimeError("No valid pages found.")
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
# Make sure we have at least one valid URL to check
|
||||
if not self.to_visit_list:
|
||||
raise ConnectorValidationError(
|
||||
"No URL configured. Please provide at least one valid URL."
|
||||
)
|
||||
|
||||
if (
|
||||
self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SITEMAP.value
|
||||
or self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value
|
||||
):
|
||||
return None
|
||||
|
||||
# We'll just test the first URL for connectivity and correctness
|
||||
test_url = self.to_visit_list[0]
|
||||
|
||||
# Check that the URL is allowed and well-formed
|
||||
try:
|
||||
protected_url_check(test_url)
|
||||
except ValueError as e:
|
||||
raise ConnectorValidationError(
|
||||
f"Protected URL check failed for '{test_url}': {e}"
|
||||
)
|
||||
except ConnectionError as e:
|
||||
# Typically DNS or other network issues
|
||||
raise ConnectorValidationError(str(e))
|
||||
|
||||
# Make a quick request to see if we get a valid response
|
||||
try:
|
||||
check_internet_connection(test_url)
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
if "401" in err_str:
|
||||
raise CredentialExpiredError(
|
||||
f"Unauthorized access to '{test_url}': {e}"
|
||||
)
|
||||
elif "403" in err_str:
|
||||
raise InsufficientPermissionsError(
|
||||
f"Forbidden access to '{test_url}': {e}"
|
||||
)
|
||||
elif "404" in err_str:
|
||||
raise ConnectorValidationError(f"Page not found for '{test_url}': {e}")
|
||||
elif "Max retries exceeded" in err_str and "NameResolutionError" in err_str:
|
||||
raise ConnectorValidationError(
|
||||
f"Unable to resolve hostname for '{test_url}'. Please check the URL and your internet connection."
|
||||
)
|
||||
else:
|
||||
# Could be a 5xx or another error, treat as unexpected
|
||||
raise UnexpectedError(f"Unexpected error validating '{test_url}': {e}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = WebConnector("https://docs.onyx.app/")
|
||||
|
||||
@@ -23,6 +23,7 @@ from onyx.context.search.preprocessing.access_filters import (
|
||||
from onyx.context.search.retrieval.search_runner import (
|
||||
remove_stop_words_and_punctuation,
|
||||
)
|
||||
from onyx.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.llm.interfaces import LLM
|
||||
@@ -34,7 +35,6 @@ from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -166,7 +166,7 @@ def retrieval_preprocessing(
|
||||
time_cutoff=time_filter or predicted_time_cutoff,
|
||||
tags=preset_filters.tags, # Tags are never auto-extracted
|
||||
access_control_list=user_acl_filters,
|
||||
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
|
||||
tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None,
|
||||
)
|
||||
|
||||
llm_evaluation_type = LLMEvaluationType.BASIC
|
||||
|
||||
@@ -17,7 +17,7 @@ from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
|
||||
def get_api_key_email_pattern() -> str:
|
||||
@@ -71,7 +71,7 @@ def insert_api_key(
|
||||
std_password_helper = PasswordHelper()
|
||||
|
||||
# Get tenant_id from context var (will be default schema for single tenant)
|
||||
tenant_id = get_current_tenant_id()
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
|
||||
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
|
||||
api_key_user_id = uuid.uuid4()
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user