Compare commits

..

2 Commits

Author SHA1 Message Date
pablodanswer
d1b9f11750 k 2025-02-17 19:17:48 -08:00
pablodanswer
9b067d8179 update error 2025-02-17 19:17:48 -08:00
272 changed files with 2517 additions and 7487 deletions

View File

@@ -145,7 +145,7 @@ jobs:
run: |
cd deployment/docker_compose
docker compose -f docker-compose.multitenant-dev.yml -p onyx-stack down -v
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
@@ -157,7 +157,6 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
@@ -200,7 +199,7 @@ jobs:
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |

View File

@@ -74,9 +74,7 @@ jobs:
python -m pip install --upgrade pip
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors

View File

@@ -1,29 +1,18 @@
name: Model Server Tests
name: Connector Tests
on:
schedule:
# This cron expression runs the job daily at 16:00 UTC (9am PT)
- cron: "0 16 * * *"
workflow_dispatch:
inputs:
branch:
description: 'Branch to run the workflow on'
required: false
default: 'main'
env:
# Bedrock
AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# API keys for testing
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
# OpenAI
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
jobs:
model-check:
@@ -37,23 +26,6 @@ jobs:
- name: Checkout code
uses: actions/checkout@v4
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Model Server Docker image
run: |
docker pull onyxdotapp/onyx-model-server:latest
docker tag onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:test
- name: Set up Python
uses: actions/setup-python@v5
with:
@@ -69,49 +41,6 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
- name: Start Docker containers
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
AUTH_TYPE=basic \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Use curl with error handling to ignore specific exit code 56
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:9000/api/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: |
@@ -127,23 +56,3 @@ jobs:
-H 'Content-type: application/json' \
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v

124
README.md
View File

@@ -24,93 +24,113 @@
</a>
</p>
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI platform connected to your company's docs, apps, and people.
Onyx provides a feature rich Chat interface and plugs into any LLM of your choice.
Keep knowledge and access controls sync-ed across over 40 connectors like Google Drive, Slack, Confluence, Salesforce, etc.
Create custom AI agents with unique prompts, knowledge, and actions that the agents can take.
Onyx can be deployed securely anywhere and for any scale - on a laptop, on-premise, or to cloud.
<strong>[Onyx](https://www.onyx.app/)</strong> (formerly Danswer) is the AI Assistant connected to your company's docs, apps, and people.
Onyx provides a Chat interface and plugs into any LLM of your choice. Onyx can be deployed anywhere and for any
scale - on a laptop, on-premise, or to cloud. Since you own the deployment, your user data and chats are fully in your
own control. Onyx is dual Licensed with most of it under MIT license and designed to be modular and easily extensible. The system also comes fully ready
for production usage with user authentication, role management (admin/basic users), chat persistence, and a UI for
configuring AI Assistants.
Onyx also serves as a Enterprise Search across all common workplace tools such as Slack, Google Drive, Confluence, etc.
By combining LLMs and team specific knowledge, Onyx becomes a subject matter expert for the team. Imagine ChatGPT if
it had access to your team's unique knowledge! It enables questions such as "A customer wants feature X, is this already
supported?" or "Where's the pull request for feature Y?"
<h3>Feature Highlights</h3>
<h3>Usage</h3>
**Deep research over your team's knowledge:**
Onyx Web App:
https://private-user-images.githubusercontent.com/32520769/414509312-48392e83-95d0-4fb5-8650-a396e05e0a32.mp4?jwt=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJnaXRodWIuY29tIiwiYXVkIjoicmF3LmdpdGh1YnVzZXJjb250ZW50LmNvbSIsImtleSI6ImtleTUiLCJleHAiOjE3Mzk5Mjg2MzYsIm5iZiI6MTczOTkyODMzNiwicGF0aCI6Ii8zMjUyMDc2OS80MTQ1MDkzMTItNDgzOTJlODMtOTVkMC00ZmI1LTg2NTAtYTM5NmUwNWUwYTMyLm1wND9YLUFtei1BbGdvcml0aG09QVdTNC1ITUFDLVNIQTI1NiZYLUFtei1DcmVkZW50aWFsPUFLSUFWQ09EWUxTQTUzUFFLNFpBJTJGMjAyNTAyMTklMkZ1cy1lYXN0LTElMkZzMyUyRmF3czRfcmVxdWVzdCZYLUFtei1EYXRlPTIwMjUwMjE5VDAxMjUzNlomWC1BbXotRXhwaXJlcz0zMDAmWC1BbXotU2lnbmF0dXJlPWFhMzk5Njg2Y2Y5YjFmNDNiYTQ2YzM5ZTg5YWJiYTU2NWMyY2YwNmUyODE2NWUxMDRiMWQxZWJmODI4YTA0MTUmWC1BbXotU2lnbmVkSGVhZGVycz1ob3N0In0.a9D8A0sgKE9AoaoE-mfFbJ6_OKYeqaf7TZ4Han2JfW8
https://github.com/onyx-dot-app/onyx/assets/32520769/563be14c-9304-47b5-bf0a-9049c2b6f410
Or, plug Onyx into your existing Slack workflows (more integrations to come 😁):
**Use Onyx as a secure AI Chat with any LLM:**
![Onyx Chat Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxChatSilentDemo.gif)
**Easily set up connectors to your apps:**
![Onyx Connector Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxConnectorSilentDemo.gif)
**Access Onyx where your team already works:**
![Onyx Bot Demo](https://github.com/onyx-dot-app/onyx/releases/download/v0.21.1/OnyxBot.png)
https://github.com/onyx-dot-app/onyx/assets/25087905/3e19739b-d178-4371-9a38-011430bdec1b
For more details on the Admin UI to manage connectors and users, check out our
<strong><a href="https://www.youtube.com/watch?v=geNzY1nbCnU">Full Video Demo</a></strong>!
## Deployment
**To try it out for free and get started in seconds, check out [Onyx Cloud](https://cloud.onyx.app/signup)**.
Onyx can also be run locally (even on a laptop) or deployed on a virtual machine with a single
Onyx can easily be run locally (even on a laptop) or deployed on a virtual machine with a single
`docker compose` command. Checkout our [docs](https://docs.onyx.app/quickstart) to learn more.
We also have built-in support for high-availability/scalable deployment on Kubernetes.
References [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment).
We also have built-in support for deployment on Kubernetes. Files for that can be found [here](https://github.com/onyx-dot-app/onyx/tree/main/deployment/kubernetes).
## 💃 Main Features
## 🔍 Other Notable Benefits of Onyx
- Custom deep learning models for indexing and inference time, only through Onyx + learning from user feedback.
- Flexible security features like SSO (OIDC/SAML/OAuth2), RBAC, encryption of credentials, etc.
- Knowledge curation features like document-sets, query history, usage analytics, etc.
- Scalable deployment options tested up to many tens of thousands users and hundreds of millions of documents.
- Chat UI with the ability to select documents to chat with.
- Create custom AI Assistants with different prompts and backing knowledge sets.
- Connect Onyx with LLM of your choice (self-host for a fully airgapped solution).
- Document Search + AI Answers for natural language queries.
- Connectors to all common workplace tools like Google Drive, Confluence, Slack, etc.
- Slack integration to get answers and search results directly in Slack.
## 🚧 Roadmap
- New methods in information retrieval (StructRAG, LightGraphRAG, etc.)
- Personalized Search
- Organizational understanding and ability to locate and suggest experts from your team.
- Code Search
- SQL and Structured Query Language
- Chat/Prompt sharing with specific teammates and user groups.
- Multimodal model support, chat with images, video etc.
- Choosing between LLMs and parameters during chat session.
- Tool calling and agent configurations options.
- Organizational understanding and ability to locate and suggest experts from your team.
## Other Notable Benefits of Onyx
- User Authentication with document level access management.
- Best in class Hybrid Search across all sources (BM-25 + prefix aware embedding models).
- Admin Dashboard to configure connectors, document-sets, access, etc.
- Custom deep learning models + learn from user feedback.
- Easy deployment and ability to host Onyx anywhere of your choosing.
## 🔌 Connectors
Keep knowledge and access up to sync across 40+ connectors:
Efficiently pulls the latest changes from:
- Slack
- GitHub
- Google Drive
- Confluence
- Slack
- Gmail
- Salesforce
- Microsoft Sharepoint
- Github
- Jira
- Zendesk
- Gmail
- Notion
- Gong
- Microsoft Teams
- Dropbox
- Slab
- Linear
- Productboard
- Guru
- Bookstack
- Document360
- Sharepoint
- Hubspot
- Local Files
- Websites
- And more ...
See the full list [here](https://docs.onyx.app/connectors).
## 📚 Editions
## 📚 Licensing
There are two editions of Onyx:
- Onyx Community Edition (CE) is available freely under the MIT Expat license. Simply follow the Deployment guide above.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations.
For feature details, check out [our website](https://www.onyx.app/pricing).
- Onyx Community Edition (CE) is available freely under the MIT Expat license. This version has ALL the core features discussed above. This is the version of Onyx you will get if you follow the Deployment guide above.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. Specifically, this includes:
- Single Sign-On (SSO), with support for both SAML and OIDC
- Role-based access control
- Document permission inheritance from connected sources
- Usage analytics and query history accessible to admins
- Whitelabeling
- API key authentication
- Encryption of secrets
- And many more! Checkout [our website](https://www.onyx.app/) for the latest.
To try the Onyx Enterprise Edition:
1. Checkout [Onyx Cloud](https://cloud.onyx.app/signup).
2. For self-hosting the Enterprise Edition, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
1. Checkout our [Cloud product](https://cloud.onyx.app/signup).
2. For self-hosting, contact us at [founders@onyx.app](mailto:founders@onyx.app) or book a call with us on our [Cal](https://cal.com/team/onyx/founders).
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
## ⭐Star History
[![Star History Chart](https://api.star-history.com/svg?repos=onyx-dot-app/onyx&type=Date)](https://star-history.com/#onyx-dot-app/onyx&Date)

View File

@@ -1,27 +0,0 @@
"""Add indexes to document__tag
Revision ID: 1a03d2c2856b
Revises: 9c00a2bccb83
Create Date: 2025-02-18 10:45:13.957807
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "1a03d2c2856b"
down_revision = "9c00a2bccb83"
branch_labels: None = None
depends_on: None = None
def upgrade() -> None:
op.create_index(
op.f("ix_document__tag_tag_id"),
"document__tag",
["tag_id"],
unique=False,
)
def downgrade() -> None:
op.drop_index(op.f("ix_document__tag_tag_id"), table_name="document__tag")

View File

@@ -1,32 +0,0 @@
"""add index
Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create a basic index on the lowercase message column for direct text matching
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
# op.execute(
# """
# CREATE INDEX idx_chat_message_message_lower
# ON chat_message (LOWER(substring(message, 1, 1500)))
# """
# )
pass
def downgrade() -> None:
# Drop the index
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -1,43 +0,0 @@
"""chat_message_agentic
Revision ID: 9c00a2bccb83
Revises: b7a7eee5aa15
Create Date: 2025-02-17 11:15:43.081150
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "9c00a2bccb83"
down_revision = "b7a7eee5aa15"
branch_labels = None
depends_on = None
def upgrade() -> None:
# First add the column as nullable
op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))
# Update existing rows based on presence of SubQuestions
op.execute(
"""
UPDATE chat_message
SET is_agentic = EXISTS (
SELECT 1
FROM agent__sub_question
WHERE agent__sub_question.primary_question_id = chat_message.id
)
WHERE is_agentic IS NULL
"""
)
# Make the column non-nullable with a default value of False
op.alter_column(
"chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
)
def downgrade() -> None:
op.drop_column("chat_message", "is_agentic")

View File

@@ -1,29 +0,0 @@
"""remove inactive ccpair status on downgrade
Revision ID: acaab4ef4507
Revises: b388730a2899
Create Date: 2025-02-16 18:21:41.330212
"""
from alembic import op
from onyx.db.models import ConnectorCredentialPair
from onyx.db.enums import ConnectorCredentialPairStatus
from sqlalchemy import update
# revision identifiers, used by Alembic.
revision = "acaab4ef4507"
down_revision = "b388730a2899"
branch_labels = None
depends_on = None
def upgrade() -> None:
pass
def downgrade() -> None:
op.execute(
update(ConnectorCredentialPair)
.where(ConnectorCredentialPair.status == ConnectorCredentialPairStatus.INVALID)
.values(status=ConnectorCredentialPairStatus.ACTIVE)
)

View File

@@ -1,31 +0,0 @@
"""nullable preferences
Revision ID: b388730a2899
Revises: 1a03d2c2856b
Create Date: 2025-02-17 18:49:22.643902
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "b388730a2899"
down_revision = "1a03d2c2856b"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column("user", "temperature_override_enabled", nullable=True)
op.alter_column("user", "auto_scroll", nullable=True)
def downgrade() -> None:
# Ensure no null values before making columns non-nullable
op.execute(
'UPDATE "user" SET temperature_override_enabled = false WHERE temperature_override_enabled IS NULL'
)
op.execute('UPDATE "user" SET auto_scroll = false WHERE auto_scroll IS NULL')
op.alter_column("user", "temperature_override_enabled", nullable=False)
op.alter_column("user", "auto_scroll", nullable=False)

View File

@@ -1,120 +0,0 @@
"""migrate jira connectors to new format
Revision ID: da42808081e3
Revises: f13db29f3101
Create Date: 2025-02-24 11:24:54.396040
"""
from alembic import op
import sqlalchemy as sa
import json
from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.utils import extract_jira_project
# revision identifiers, used by Alembic.
revision = "da42808081e3"
down_revision = "f13db29f3101"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config
for connector_id, old_config in jira_connectors:
if not old_config:
continue
# Extract project key from URL if it exists
new_config: dict[str, str | None] = {}
if project_url := old_config.get("jira_project_url"):
# Parse the URL to get base and project
try:
jira_base, project_key = extract_jira_project(project_url)
new_config = {"jira_base_url": jira_base, "project_key": project_key}
except ValueError:
# If URL parsing fails, just use the URL as the base
new_config = {
"jira_base_url": project_url.split("/projects/")[0],
"project_key": None,
}
else:
# For connectors without a project URL, we need admin intervention
# Mark these for review
print(
f"WARNING: Jira connector {connector_id} has no project URL configured"
)
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :id
"""
),
{"id": connector_id, "new_config": json.dumps(new_config)},
)
def downgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config back to the old format
for connector_id, new_config in jira_connectors:
if not new_config:
continue
old_config = {}
base_url = new_config.get("jira_base_url")
project_key = new_config.get("project_key")
if base_url and project_key:
old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
elif base_url:
old_config = {"jira_project_url": base_url}
else:
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :old_config
WHERE id = :id
"""
),
{"id": connector_id, "old_config": old_config},
)

View File

@@ -1,27 +0,0 @@
"""Add composite index for last_modified and last_synced to document
Revision ID: f13db29f3101
Revises: b388730a2899
Create Date: 2025-02-18 22:48:11.511389
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f13db29f3101"
down_revision = "acaab4ef4507"
branch_labels: str | None = None
depends_on: str | None = None
def upgrade() -> None:
op.create_index(
"ix_document_sync_status",
"document",
["last_modified", "last_synced"],
unique=False,
)
def downgrade() -> None:
op.drop_index("ix_document_sync_status", table_name="document")

View File

@@ -5,9 +5,11 @@ from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -16,8 +18,10 @@ logger = setup_logger()
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
def perform_ttl_management_task(
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -31,19 +35,24 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str) -> None:
def check_ttl_management_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
if MULTI_TENANT and tenant_id is not None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@celery_app.task(
@@ -51,9 +60,9 @@ def check_ttl_management_task(*, tenant_id: str) -> None:
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,

View File

@@ -18,7 +18,7 @@ logger = setup_logger()
def monitor_usergroup_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
"""This function is likely to move in the worker refactor happening next."""
fence_key = key_bytes.decode("utf-8")

View File

@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import UserGroup__ConnectorCredentialPair
@@ -36,11 +35,10 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
def get_cc_pairs_by_source(
db_session: Session,
source_type: DocumentSource,
access_type: AccessType | None = None,
status: ConnectorCredentialPairStatus | None = None,
only_sync: bool,
) -> list[ConnectorCredentialPair]:
"""
Get all cc_pairs for a given source type with optional filtering by access_type and status
Get all cc_pairs for a given source type (and optionally only sync)
result is sorted by cc_pair id
"""
query = (
@@ -50,11 +48,8 @@ def get_cc_pairs_by_source(
.order_by(ConnectorCredentialPair.id)
)
if access_type is not None:
query = query.filter(ConnectorCredentialPair.access_type == access_type)
if status is not None:
query = query.filter(ConnectorCredentialPair.status == status)
if only_sync:
query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)
cc_pairs = query.all()
return cc_pairs

View File

@@ -14,24 +14,30 @@ def _build_group_member_email_map(
confluence_client: OnyxConfluence, cc_pair_id: int
) -> dict[str, set[str]]:
group_member_emails: dict[str, set[str]] = {}
for user in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user}")
for user_result in confluence_client.paginated_cql_user_retrieval():
logger.debug(f"Processing groups for user: {user_result}")
email = user.email
user = user_result.get("user", {})
if not user:
msg = f"user result missing user field: {user_result}"
emit_background_error(msg, cc_pair_id=cc_pair_id)
logger.error(msg)
continue
email = user.get("email")
if not email:
# This field is only present in Confluence Server
user_name = user.username
user_name = user.get("username")
# If it is present, try to get the email using a Server-specific method
if user_name:
email = get_user_email_from_username__server(
confluence_client=confluence_client,
user_name=user_name,
)
if not email:
# If we still don't have an email, skip this user
msg = f"user result missing email field: {user}"
if user.type == "app":
msg = f"user result missing email field: {user_result}"
if user.get("type") == "app":
logger.warning(msg)
else:
emit_background_error(msg, cc_pair_id=cc_pair_id)
@@ -39,7 +45,7 @@ def _build_group_member_email_map(
continue
all_users_groups: set[str] = set()
for group in confluence_client.paginated_groups_by_user_retrieval(user.user_id):
for group in confluence_client.paginated_groups_by_user_retrieval(user):
# group name uniqueness is enforced by Confluence, so we can use it as a group ID
group_id = group["name"]
group_member_emails.setdefault(group_id, set()).add(email)

View File

@@ -62,14 +62,12 @@ def _fetch_permissions_for_permission_ids(
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# We continue on 404 or 403 because the document may not exist or the user may not have access to it
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
fileId=doc_id,
fields="permissions(id, emailAddress, type, domain)",
supportsAllDrives=True,
continue_on_404_or_403=True,
)
permissions_for_doc_id = []
@@ -106,13 +104,7 @@ def _get_permissions_from_slim_doc(
user_emails: set[str] = set()
group_emails: set[str] = set()
public = False
skipped_permissions = 0
for permission in permissions_list:
if not permission:
skipped_permissions += 1
continue
permission_type = permission["type"]
if permission_type == "user":
user_emails.add(permission["emailAddress"])
@@ -129,11 +121,6 @@ def _get_permissions_from_slim_doc(
elif permission_type == "anyone":
public = True
if skipped_permissions > 0:
logger.warning(
f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
)
drive_id = permission_info.get("drive_id")
group_ids = group_emails | ({drive_id} if drive_id is not None else set())

View File

@@ -33,7 +33,7 @@ def add_tenant_id_middleware(app: FastAPI, logger: logging.LoggerAdapter) -> Non
return await call_next(request)
except Exception as e:
logger.exception(f"Error in tenant ID middleware: {str(e)}")
logger.error(f"Error in tenant ID middleware: {str(e)}")
raise
@@ -49,7 +49,7 @@ async def _get_tenant_id_from_request(
"""
# Check for API key
tenant_id = extract_tenant_from_api_key_header(request)
if tenant_id is not None:
if tenant_id:
return tenant_id
# Check for anonymous user cookie

View File

@@ -36,12 +36,12 @@ from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -271,12 +271,12 @@ def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
tenant_id = get_current_tenant_id()
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
@@ -329,6 +329,7 @@ def handle_slack_oauth_callback(
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
@@ -336,7 +337,7 @@ def handle_slack_oauth_callback(
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
@@ -522,6 +523,7 @@ def handle_google_drive_oauth_callback(
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
@@ -529,7 +531,7 @@ def handle_google_drive_oauth_callback(
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (

View File

@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
def _check_token_rate_limits(user: User | None) -> None:
def _check_token_rate_limits(user: User | None, tenant_id: str | None) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global()
_user_is_rate_limited_by_global(tenant_id)
elif is_api_key_email_address(user.email):
# API keys are only rate limited by global settings
_user_is_rate_limited_by_global()
_user_is_rate_limited_by_global(tenant_id)
else:
run_functions_tuples_in_parallel(
[
(_user_is_rate_limited, (user.id,)),
(_user_is_rate_limited_by_group, (user.id,)),
(_user_is_rate_limited_by_global, ()),
(_user_is_rate_limited, (user.id, tenant_id)),
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
(_user_is_rate_limited_by_global, (tenant_id,)),
]
)
@@ -52,8 +52,8 @@ User rate limits
"""
def _user_is_rate_limited(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
def _user_is_rate_limited(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
@@ -93,8 +93,8 @@ User Group rate limits
"""
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
if group_rate_limits:

View File

@@ -2,7 +2,6 @@ import csv
import io
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from uuid import UUID
from fastapi import APIRouter
@@ -22,10 +21,8 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
@@ -38,8 +35,6 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter()
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
def fetch_and_process_chat_session_history(
db_session: Session,
@@ -112,17 +107,6 @@ def get_user_chat_sessions(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
# we specifically don't allow this endpoint if "anonymized" since
# this is a direct query on the user id
if ONYX_QUERY_HISTORY_TYPE in [
QueryHistoryType.DISABLED,
QueryHistoryType.ANONYMIZED,
]:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Per user query history has been disabled by the administrator.",
)
try:
chat_sessions = get_chat_sessions_by_user(
user_id=user_id, deleted=False, db_session=db_session, limit=0
@@ -157,12 +141,6 @@ def get_chat_session_history(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
page_size=page_size,
@@ -179,16 +157,11 @@ def get_chat_session_history(
feedback_filter=feedback_type,
)
minimal_chat_sessions: list[ChatSessionMinimal] = []
for chat_session in page_of_chat_sessions:
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
minimal_chat_sessions.append(minimal_chat_session)
return PaginatedReturn(
items=minimal_chat_sessions,
items=[
ChatSessionMinimal.from_chat_session(chat_session)
for chat_session in page_of_chat_sessions
],
total_items=total_filtered_chat_sessions_count,
)
@@ -199,12 +172,6 @@ def get_chat_session_admin(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
try:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
@@ -226,9 +193,6 @@ def get_chat_session_admin(
f"Could not create snapshot for chat session with id '{chat_session_id}'",
)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
return snapshot
@@ -239,12 +203,6 @@ def get_query_history_as_csv(
end: datetime | None = None,
db_session: Session = Depends(get_session),
) -> StreamingResponse:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
@@ -255,9 +213,6 @@ def get_query_history_as_csv(
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
for chat_session_snapshot in complete_chat_session_history:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
question_answer_pairs.extend(
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
)

View File

@@ -41,15 +41,14 @@ from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
@@ -58,14 +57,13 @@ router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
tenant_id: str | None = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@@ -74,15 +72,15 @@ async def get_anonymous_user_path_api(
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
tenant_id: str = Depends(get_current_tenant_id),
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
@@ -103,7 +101,7 @@ async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
@@ -152,17 +150,14 @@ async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
return fetch_billing_information(CURRENT_TENANT_ID_CONTEXTVAR.get())
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
async def create_customer_portal_session(_: User = Depends(current_admin_user)) -> dict:
try:
# Fetch tenant_id and current tenant's information
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
@@ -186,8 +181,6 @@ async def create_subscription_session(
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
@@ -204,7 +197,7 @@ async def impersonate_user(
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
with get_session_with_tenant(tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
@@ -228,9 +221,8 @@ async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str = Depends(get_current_tenant_id),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"

View File

@@ -118,7 +118,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
@@ -134,7 +134,7 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_tenant(tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
@@ -224,7 +224,7 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
default_model_name="claude-3-5-sonnet-20241022",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)

View File

@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(POSTGRES_DEFAULT_SCHEMA) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()

View File

@@ -98,17 +98,12 @@ class CloudEmbedding:
return final_embeddings
except Exception as e:
error_string = (
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
f"Error embedding text with OpenAI: {str(e)} \n"
f"Model: {model} \n"
f"Provider: {self.provider} \n"
f"Texts: {texts}"
)
logger.error(error_string)
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
async def _embed_cohere(

View File

@@ -10,7 +10,6 @@ from pydantic import BaseModel
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
from shared_configs.configs import MULTI_TENANT
_API_KEY_HEADER_NAME = "Authorization"
@@ -36,7 +35,8 @@ class ApiKeyDescriptor(BaseModel):
def generate_api_key(tenant_id: str | None = None) -> str:
if not MULTI_TENANT or not tenant_id:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
encoded_tenant = quote(tenant_id) # URL encode the tenant ID

View File

@@ -2,8 +2,6 @@ import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
from email.utils import make_msgid
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
@@ -12,10 +10,8 @@ from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from shared_configs.configs import MULTI_TENANT
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
@@ -153,9 +149,8 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
if mail_from:
msg["From"] = mail_from
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
@@ -177,7 +172,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
subject = "Your Onyx Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>We're sorry to see you go.</p>"
"<p>Were sorry to see you go.</p>"
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
"<p>If you change your mind, you can always come back!</p>"
)
@@ -192,64 +187,36 @@ def send_subscription_cancellation_email(user_email: str) -> None:
send_email(user_email, subject, html_content, text_content)
def send_user_email_invite(
user_email: str, current_user: User, auth_type: AuthType
) -> None:
def send_user_email_invite(user_email: str, current_user: User) -> None:
subject = "Invitation to Join Onyx Organization"
heading = "You've Been Invited!"
# the exact action taken by the user, and thus the message, depends on the auth type
message = f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
if auth_type == AuthType.CLOUD:
message += (
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
elif auth_type == AuthType.BASIC:
message += (
"<p>To join the organization, please click the button below to set a password "
"and complete your registration.</p>"
)
elif auth_type == AuthType.GOOGLE_OAUTH:
message += (
"<p>To join the organization, please click the button below to login with Google "
"and complete your registration.</p>"
)
elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML:
message += (
"<p>To join the organization, please click the button below to"
" complete your registration.</p>"
)
else:
raise ValueError(f"Invalid auth type: {auth_type}")
message = (
f"<p>You have been invited by {current_user.email} to join an organization on Onyx.</p>"
"<p>To join the organization, please click the button below to set a password "
"or login with Google and complete your registration.</p>"
)
cta_text = "Join Organization"
cta_link = f"{WEB_DOMAIN}/auth/signup?email={user_email}"
html_content = build_html_email(heading, message, cta_text, cta_link)
# text content is the fallback for clients that don't support HTML
# not as critical, so not having special cases for each auth type
text_content = (
f"You have been invited by {current_user.email} to join an organization on Onyx.\n"
"To join the organization, please visit the following link:\n"
f"{WEB_DOMAIN}/auth/signup?email={user_email}\n"
"You'll be asked to set a password or login with Google to complete your registration."
)
if auth_type == AuthType.CLOUD:
text_content += "You'll be asked to set a password or login with Google to complete your registration."
send_email(user_email, subject, html_content, text_content)
def send_forgot_password_email(
user_email: str,
token: str,
tenant_id: str,
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
if tenant_id:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)

View File

@@ -42,5 +42,4 @@ def fetch_no_auth_user(
role=UserRole.BASIC if anonymous_user_enabled else UserRole.ADMIN,
preferences=load_no_auth_user_preferences(store),
is_anonymous_user=anonymous_user_enabled,
password_configured=False,
)

View File

@@ -1,7 +1,5 @@
import json
import random
import secrets
import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
@@ -88,6 +86,7 @@ from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine import get_async_session
from onyx.db.engine import get_async_session_with_tenant
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
@@ -105,7 +104,6 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -141,30 +139,6 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:
return email or ""
def generate_password() -> str:
lowercase_letters = string.ascii_lowercase
uppercase_letters = string.ascii_uppercase
digits = string.digits
special_characters = string.punctuation
# Ensure at least one of each required character type
password = [
secrets.choice(uppercase_letters),
secrets.choice(digits),
secrets.choice(special_characters),
]
# Fill the rest with a mix of characters
remaining_length = 12 - len(password)
all_characters = lowercase_letters + uppercase_letters + digits + special_characters
password.extend(secrets.choice(all_characters) for _ in range(remaining_length))
# Shuffle the password to randomize the position of the required characters
random.shuffle(password)
return "".join(password)
def user_needs_to_be_verified() -> bool:
if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
return REQUIRE_EMAIL_VERIFICATION
@@ -214,8 +188,8 @@ def verify_email_is_invited(email: str) -> None:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
with get_session_with_tenant(tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -420,7 +394,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = cast(User, await self.user_db.get_by_email(account_email))
user = await self.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()
@@ -553,7 +527,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async_return_default_schema,
)(email=user.email)
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None
@@ -617,39 +591,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
return user
async def reset_password_as_admin(self, user_id: uuid.UUID) -> str:
"""Admin-only. Generate a random password for a user and return it."""
user = await self.get(user_id)
new_password = generate_password()
await self._update(user, {"password": new_password})
return new_password
async def change_password_if_old_matches(
self, user: User, old_password: str, new_password: str
) -> None:
"""
For normal users to change password if they know the old one.
Raises 400 if old password doesn't match.
"""
verified, updated_password_hash = self.password_helper.verify_and_update(
old_password, user.hashed_password
)
if not verified:
# Raise some HTTPException (or your custom exception) if old password is invalid:
from fastapi import HTTPException, status
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid current password",
)
# If the hash was upgraded behind the scenes, we can keep it before setting the new password:
if updated_password_hash:
user.hashed_password = updated_password_hash
# Now apply and validate the new password
await self._update(user, {"password": new_password})
async def get_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
@@ -874,9 +815,8 @@ async def current_limited_user(
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> User | None:
tenant_id = get_current_tenant_id()
return await double_check_user(
user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
)

View File

@@ -2,7 +2,6 @@ import logging
import multiprocessing
import time
from typing import Any
from typing import cast
import sentry_sdk
from celery import Task
@@ -34,7 +33,6 @@ from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGrou
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import ColoredFormatter
from onyx.utils.logger import PlainFormatter
@@ -60,35 +58,13 @@ else:
logger.debug("Sentry DSN not provided, skipping Sentry initialization")
class TenantAwareTask(Task):
"""A custom base Task that sets tenant_id in a contextvar before running."""
abstract = True # So Celery knows not to register this as a real task.
def __call__(self, *args: Any, **kwargs: Any) -> Any:
# Grab tenant_id from the kwargs, or fallback to default if missing.
tenant_id = kwargs.get("tenant_id", None) or POSTGRES_DEFAULT_SCHEMA
# Set the context var
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Actually run the task now
try:
return super().__call__(*args, **kwargs)
finally:
# Clear or reset after the task runs
# so it does not leak into any subsequent tasks on the same worker process
CURRENT_TENANT_ID_CONTEXTVAR.set(None)
@task_prerun.connect
def on_task_prerun(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
**kwds: Any,
) -> None:
pass
@@ -132,9 +108,9 @@ def on_task_postrun(
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = POSTGRES_DEFAULT_SCHEMA
tenant_id = None
else:
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
@@ -225,7 +201,7 @@ def wait_for_redis(sender: Any, **kwargs: Any) -> None:
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_shared_redis_client()
r = get_redis_client(tenant_id=None)
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
@@ -311,7 +287,7 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5
WAIT_LIMIT = 60
r = get_shared_redis_client()
r = get_redis_client(tenant_id=None)
time_start = time.monotonic()
logger.info("Waiting for primary worker to be ready...")
@@ -463,6 +439,24 @@ class TenantContextFilter(logging.Filter):
return True
@task_prerun.connect
def set_tenant_id(
sender: Any | None = None,
task_id: str | None = None,
task: Task | None = None,
args: tuple[Any, ...] | None = None,
kwargs: dict[str, Any] | None = None,
**other_kwargs: Any,
) -> None:
"""Signal handler to set tenant ID in context var before task starts."""
tenant_id = (
kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if kwargs
else POSTGRES_DEFAULT_SCHEMA
)
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@task_postrun.connect
def reset_tenant_id(
sender: Any | None = None,

View File

@@ -132,7 +132,6 @@ class DynamicTenantScheduler(PersistentScheduler):
f"Adding options to task {tenant_task_name}: {options}"
)
tenant_task["options"] = options
new_schedule[tenant_task_name] = tenant_task
return new_schedule
@@ -257,4 +256,3 @@ def on_setup_logging(
celery_app.conf.beat_scheduler = DynamicTenantScheduler
celery_app.conf.task_default_base = app_base.TenantAwareTask

View File

@@ -20,7 +20,6 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.heavy")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -21,7 +21,6 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.indexing")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -23,7 +23,6 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.light")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -20,7 +20,6 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.monitoring")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect

View File

@@ -24,7 +24,7 @@ from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_default_tenant
from onyx.db.engine import SqlEngine
from onyx.db.index_attempt import get_index_attempt
from onyx.db.index_attempt import mark_attempt_canceled
@@ -38,7 +38,7 @@ from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_pool import get_shared_redis_client
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -47,7 +47,6 @@ logger = setup_logger()
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.primary")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]
@signals.task_prerun.connect
@@ -102,7 +101,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
# This is singleton work that should be done on startup exactly once
# by the primary worker. This is unnecessary in the multi tenant scenario
r = get_shared_redis_client()
r = get_redis_client(tenant_id=None)
# Log the role and slave count - being connected to a slave or slave count > 0 could be problematic
info: dict[str, Any] = cast(dict, r.info("replication"))
@@ -159,7 +158,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
RedisConnectorExternalGroupSync.reset_all(r)
# mark orphaned index attempts as failed
with get_session_with_current_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(db_session, r)
for attempt_id in unfenced_attempt_ids:
attempt = get_index_attempt(db_session, attempt_id)
@@ -235,7 +234,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
lock: RedisLock = worker.primary_worker_lock
r = get_shared_redis_client()
r = get_redis_client(tenant_id=None)
if lock.owned():
task_logger.debug("Reacquiring primary worker lock.")

View File

@@ -92,8 +92,7 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue and return them
as a set.
"""This is a redis specific way to build a list of tasks in a queue.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.

View File

@@ -34,7 +34,7 @@ def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
tenant_id: str | None = None,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
tenant_id: str | None = None,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id

View File

@@ -8,21 +8,16 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
@@ -32,7 +27,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -57,51 +52,6 @@ class TaskDependencyError(RuntimeError):
with connector deletion."""
def revoke_tasks_blocking_deletion(
redis_connector: RedisConnector, db_session: Session, app: Celery
) -> None:
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
try:
index_payload = redis_connector_index.payload
if index_payload and index_payload.celery_task_id:
app.control.revoke(index_payload.celery_task_id)
task_logger.info(
f"Revoked indexing task {index_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking indexing task")
try:
permissions_sync_payload = redis_connector.permissions.payload
if permissions_sync_payload and permissions_sync_payload.celery_task_id:
app.control.revoke(permissions_sync_payload.celery_task_id)
task_logger.info(
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking pruning task")
try:
prune_payload = redis_connector.prune.payload
if prune_payload and prune_payload.celery_task_id:
app.control.revoke(prune_payload.celery_task_id)
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
except Exception:
task_logger.exception("Exception while revoking permissions sync task")
try:
external_group_sync_payload = redis_connector.external_group_sync.payload
if external_group_sync_payload and external_group_sync_payload.celery_task_id:
app.control.revoke(external_group_sync_payload.celery_task_id)
task_logger.info(
f"Revoked external group sync task {external_group_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking external group sync task")
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
ignore_result=True,
@@ -109,46 +59,32 @@ def revoke_tasks_blocking_deletion(
trail=False,
bind=True,
)
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# Prevent this task from overlapping with itself
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)
except Exception:
task_logger.exception(
"Exception while validating connector deletion fences"
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair in cc_pairs:
cc_pair_ids.append(cc_pair.id)
# try running cleanup on the cc_pair_ids
for cc_pair_id in cc_pair_ids:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
redis_connector = RedisConnector(tenant_id, cc_pair_id)
try:
try_generate_document_cc_pair_cleanup_tasks(
@@ -156,38 +92,9 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# on the first error, we set a stop signal and revoke the dependent tasks
# on subsequent errors, we hard reset blocking fences after our specified timeout
# is exceeded
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
if not redis_connector.stop.fenced:
# one time revoke of celery tasks
task_logger.info("Revoking any tasks blocking deletion.")
revoke_tasks_blocking_deletion(
redis_connector, db_session, self.app
)
redis_connector.stop.set_fence(True)
redis_connector.stop.set_timeout()
else:
# stop signal already set
if redis_connector.stop.timed_out:
# waiting too long, just reset blocking fences
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()
else:
# just wait
pass
redis_connector.stop.set_fence(True)
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
@@ -222,7 +129,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
cc_pair_id: int,
db_session: Session,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
@@ -262,7 +169,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
redis_connector.delete.set_active()
fence_payload = RedisConnectorDeletePayload(
num_tasks=None,
submitted=datetime.now(timezone.utc),
@@ -343,7 +249,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
def monitor_connector_deletion_taskset(
tenant_id: str, key_bytes: bytes, r: Redis
tenant_id: str | None, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -371,7 +277,7 @@ def monitor_connector_deletion_taskset(
f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}"
)
if remaining > 0:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
@@ -381,7 +287,7 @@ def monitor_connector_deletion_taskset(
)
return
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -495,171 +401,3 @@ def monitor_connector_deletion_taskset(
)
redis_connector.delete.reset()
def validate_connector_deletion_fences(
tenant_id: str,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
# validate all existing connector deletion jobs
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
continue
validate_connector_deletion_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
r,
)
lock_beat.reacquire()
return
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.delete.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.delete.payload
except ValidationError:
task_logger.exception(
"validate_connector_deletion_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return
if not payload:
return
# OK, there's actually something for us to validate
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own permissions taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.delete.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_connector_deletion_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're active if there are still tasks to run and those tasks all exist in celery
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.delete.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.delete.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_connector_deletion_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return

View File

@@ -30,7 +30,6 @@ from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -43,12 +42,10 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -66,7 +63,6 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
@@ -123,13 +119,13 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None:
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# TODO(rkuo): merge into check function after lookup table for fences is added
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -144,7 +140,7 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
try:
# get all cc pairs that need to be synced
cc_pair_ids_to_sync: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
for cc_pair in cc_pairs:
@@ -193,23 +189,16 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
task_logger.info(f"check_for_doc_permissions_sync finished: tenant={tenant_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id} {error_msg}"
)
task_logger.exception(
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id}"
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
@@ -221,7 +210,7 @@ def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Returns a randomized payload id on success.
Returns None if no syncing is required."""
@@ -258,7 +247,7 @@ def try_creating_permissions_sync_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
@@ -293,19 +282,13 @@ def try_creating_permissions_sync_task(
redis_connector.permissions.set_fence(payload)
payload_id = payload.id
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_permissions_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"try_creating_permissions_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
return payload_id
@@ -320,7 +303,7 @@ def try_creating_permissions_sync_task(
def connector_permission_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
@@ -338,7 +321,7 @@ def connector_permission_sync_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -395,7 +378,7 @@ def connector_permission_sync_generator_task(
return None
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -405,29 +388,6 @@ def connector_permission_sync_generator_task(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -479,10 +439,6 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
)
task_logger.exception(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
)
@@ -509,7 +465,7 @@ def connector_permission_sync_generator_task(
)
def update_external_document_permissions_task(
self: Task,
tenant_id: str,
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
@@ -517,8 +473,6 @@ def update_external_document_permissions_task(
) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
@@ -526,8 +480,7 @@ def update_external_document_permissions_task(
external_access = document_external_access.external_access
try:
with get_session_with_current_tenant() as db_session:
# Add the users to the DB if they don't exist
with get_session_with_tenant(tenant_id) as db_session:
batch_add_ext_perm_user_if_not_exists(
db_session=db_session,
emails=list(external_access.external_user_emails),
@@ -558,33 +511,18 @@ def update_external_document_permissions_task(
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
)
except Exception:
task_logger.exception(
f"update_external_document_permissions_task exceptioned: "
f"Exception in update_external_document_permissions_task: "
f"connector_id={connector_id} doc_id={doc_id}"
)
completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
finally:
task_logger.info(
f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
task_logger.info(
f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
)
return True
def validate_permission_sync_fences(
tenant_id: str,
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
@@ -631,7 +569,7 @@ def validate_permission_sync_fences(
def validate_permission_sync_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
reserved_tasks: set[str],
@@ -841,7 +779,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
def monitor_ccpair_permissions_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)

View File

@@ -37,12 +37,9 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
@@ -58,7 +55,6 @@ from onyx.redis.redis_connector_ext_group_sync import (
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -123,11 +119,11 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client()
r_replica = get_redis_replica_client()
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -144,7 +140,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
try:
cc_pair_ids_to_sync: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
# We only want to sync one cc_pair per source type in
@@ -152,10 +148,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session,
source,
access_type=AccessType.SYNC,
status=ConnectorCredentialPairStatus.ACTIVE,
db_session, source, only_sync=True
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
@@ -202,17 +195,12 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected check_for_external_group_sync exception: tenant={tenant_id} {error_msg}"
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
task_logger.info(f"check_for_external_group_sync finished: tenant={tenant_id}")
return True
@@ -220,7 +208,7 @@ def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
@@ -242,7 +230,7 @@ def try_creating_external_group_sync_task(
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
insert_sync_record(
db_session=db_session,
entity_id=cc_pair_id,
@@ -279,19 +267,12 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.set_fence(payload)
payload_id = payload.id
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_external_group_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
task_logger.info(
f"try_creating_external_group_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
return payload_id
@@ -306,7 +287,7 @@ def try_creating_external_group_sync_task(
def connector_external_group_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""
External group sync task for a given connector credential pair
@@ -315,7 +296,7 @@ def connector_external_group_sync_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -376,40 +357,16 @@ def connector_external_group_sync_generator_task(
payload.started = datetime.now(timezone.utc)
redis_connector.external_group_sync.set_fence(payload)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
eager_load_credential=True,
)
if cc_pair is None:
raise ValueError(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -421,23 +378,12 @@ def connector_external_group_sync_generator_task(
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = []
try:
external_user_groups = ext_group_sync_func(cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise e
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
)
logger.debug(f"New external user groups: {external_user_groups}")
replace_user__ext_group_for_cc_pair(
db_session=db_session,
@@ -458,19 +404,11 @@ def connector_external_group_sync_generator_task(
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id} {error_msg}"
)
task_logger.exception(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
)
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
task_logger.exception(msg)
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
update_sync_record_status(
db_session=db_session,
entity_id=cc_pair_id,
@@ -493,7 +431,7 @@ def connector_external_group_sync_generator_task(
def validate_external_group_sync_fences(
tenant_id: str,
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_replica: Redis,
@@ -521,11 +459,12 @@ def validate_external_group_sync_fences(
)
lock_beat.reacquire()
return
def validate_external_group_sync_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,

View File

@@ -41,18 +41,16 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingMode
from onyx.db.enums import IndexingStatus
from onyx.db.index_attempt import get_index_attempt
@@ -92,9 +90,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
SUCCEEDED = "succeeded"
SPAWN_FAILED = "spawn_failed" # connector spawn failed
SPAWN_NOT_ALIVE = (
"spawn_not_alive" # spawn succeeded but process did not come alive
)
BLOCKED_BY_DELETION = "blocked_by_deletion"
BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
@@ -108,9 +103,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
"index_attempt_mismatch" # expected index attempt metadata not found in db
)
CONNECTOR_VALIDATION_ERROR = (
"connector_validation_error" # the connector validation failed
)
CONNECTOR_EXCEPTIONED = "connector_exceptioned" # the connector itself exceptioned
WATCHDOG_EXCEPTIONED = "watchdog_exceptioned" # the watchdog exceptioned
@@ -120,8 +112,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
# the watchdog terminated the task due to no activity
TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"
# NOTE: this may actually be the same as SIGKILL, but parsed differently by python
# consolidate once we know more
OUT_OF_MEMORY = "out_of_memory"
PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"
@@ -131,7 +121,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
_ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
@@ -148,8 +137,6 @@ class IndexingWatchdogTerminalStatus(str, Enum):
def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
_CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
-9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
@@ -182,7 +169,7 @@ class SimpleJobResult:
class ConnectorIndexingContext(BaseModel):
tenant_id: str
tenant_id: str | None
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
@@ -210,7 +197,7 @@ class ConnectorIndexingLogBuilder:
def monitor_ccpair_indexing_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
@@ -358,16 +345,15 @@ def monitor_ccpair_indexing_taskset(
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
time_start = time.monotonic()
tasks_created = 0
locked = False
redis_client = get_redis_client()
redis_client_replica = get_redis_replica_client()
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client_replica = get_redis_replica_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
@@ -405,7 +391,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# 1/3: KICKOFF
# check for search settings swap
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
@@ -426,7 +412,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# gather cc_pair_ids
lock_beat.reacquire()
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = fetch_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
@@ -436,7 +422,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
lock_beat.reacquire()
redis_connector = RedisConnector(tenant_id, cc_pair_id)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
search_settings_list = get_active_search_settings_list(db_session)
for search_settings_instance in search_settings_list:
redis_connector_index = redis_connector.new_index(
@@ -514,7 +500,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# Fail any index attempts in the DB that don't have fences
# This shouldn't ever happen!
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
unfenced_attempt_ids = get_unfenced_index_attempt_ids(
db_session, redis_client
)
@@ -566,7 +552,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_indexing_taskset(
tenant_id, key_bytes, redis_client_replica, db_session
)
@@ -597,8 +583,8 @@ def connector_indexing_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str | None,
is_ee: bool,
tenant_id: str,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -649,7 +635,7 @@ def connector_indexing_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
if redis_connector.delete.fenced:
raise SimpleJobException(
@@ -743,7 +729,7 @@ def connector_indexing_task(
redis_connector_index.set_fence(payload)
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
attempt = get_index_attempt(db_session, index_attempt_id)
if not attempt:
raise SimpleJobException(
@@ -778,9 +764,9 @@ def connector_indexing_task(
callback = IndexingCallback(
os.getppid(),
redis_connector,
redis_connector_index,
lock,
r,
redis_connector_index,
)
logger.info(
@@ -802,15 +788,6 @@ def connector_indexing_task(
# get back the total number of indexed docs and return it
n_final_progress = redis_connector_index.get_progress()
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
except ConnectorValidationError:
raise SimpleJobException(
f"Indexing task failed: attempt={index_attempt_id} "
f"tenant={tenant_id} "
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}",
code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code,
)
except Exception as e:
logger.exception(
f"Indexing spawned task failed: attempt={index_attempt_id} "
@@ -818,8 +795,8 @@ def connector_indexing_task(
f"cc_pair={cc_pair_id} "
f"search_settings={search_settings_id}"
)
raise e
raise e
finally:
if lock.owned():
lock.release()
@@ -890,7 +867,7 @@ def connector_indexing_proxy_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
@@ -899,9 +876,6 @@ def connector_indexing_proxy_task(
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
NOTE: we try/except all db access in this function because as a watchdog, this function
needs to be extremely stable.
"""
start = time.monotonic()
@@ -927,18 +901,18 @@ def connector_indexing_proxy_task(
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
task_logger.info(f"submitting connector_indexing_task with tenant_id={tenant_id}")
job = client.submit(
connector_indexing_task,
index_attempt_id,
cc_pair_id,
search_settings_id,
global_version.is_ee_version(),
tenant_id,
global_version.is_ee_version(),
pure=False,
)
if not job or not job.process:
if not job:
result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED
task_logger.info(
log_builder.build(
@@ -949,39 +923,13 @@ def connector_indexing_proxy_task(
)
return
# Ensure the process has moved out of the starting state
num_waits = 0
while True:
if num_waits > 15:
result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE
task_logger.info(
log_builder.build(
"Indexing watchdog - finished",
status=str(result.status.value),
exit_code=str(result.exit_code),
)
)
job.release()
return
if job.process.is_alive() or job.process.exitcode is not None:
break
sleep(1)
num_waits += 1
task_logger.info(
log_builder.build(
"Indexing watchdog - spawn succeeded",
pid=str(job.process.pid),
)
)
task_logger.info(log_builder.build("Indexing watchdog - spawn succeeded"))
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
@@ -992,9 +940,6 @@ def connector_indexing_proxy_task(
index_attempt.connector_credential_pair.connector.source.value
)
redis_connector_index.set_active() # renew active signal
redis_connector_index.set_connector_active() # prime the connective active signal
while True:
sleep(5)
@@ -1020,7 +965,7 @@ def connector_indexing_proxy_task(
job.release()
break
# if a termination signal is detected, break (exit point will clean up)
# if a termination signal is detected, clean up and break
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
log_builder.build("Indexing watchdog - termination signal detected")
@@ -1029,24 +974,10 @@ def connector_indexing_proxy_task(
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
break
# if activity timeout is detected, break (exit point will clean up)
if not redis_connector_index.connector_active():
task_logger.warning(
log_builder.build(
"Indexing watchdog - activity timeout exceeded",
timeout=f"{CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
)
result.status = (
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
)
break
# if the spawned task is still running, restart the check once again
# if the index attempt is not in a finished status
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
index_attempt = get_index_attempt(
db_session=db_session, index_attempt_id=index_attempt_id
)
@@ -1056,29 +987,25 @@ def connector_indexing_proxy_task(
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception as e:
except Exception:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
if isinstance(e, ConnectorValidationError):
# No need to expose full stack trace for validation errors
result.exception_str = str(e)
else:
result.exception_str = traceback.format_exc()
result.exception_str = traceback.format_exc()
# handle exit and reporting
elapsed = time.monotonic() - start
if result.exception_str is not None:
# print with exception
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
failure_reason = (
f"Spawned task exceptioned: exit_code={result.exit_code}"
)
@@ -1118,13 +1045,15 @@ def connector_indexing_proxy_task(
# print without exception
if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL:
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
mark_attempt_canceled(
index_attempt_id,
db_session,
"Connector termination signal detected",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
@@ -1132,25 +1061,6 @@ def connector_indexing_proxy_task(
)
job.cancel()
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
else:
pass
task_logger.info(
log_builder.build(
@@ -1170,7 +1080,7 @@ def connector_indexing_proxy_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
)
def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
"""Clean up old checkpoints that are older than 7 days."""
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)
@@ -1185,7 +1095,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
try:
locked = True
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
old_attempts = get_index_attempts_with_old_checkpoints(db_session)
for attempt in old_attempts:
task_logger.info(
@@ -1221,5 +1131,5 @@ def cleanup_checkpoint_task(
self: Task, *, index_attempt_id: int, tenant_id: str | None
) -> None:
"""Clean up a checkpoint for a given index attempt"""
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
cleanup_checkpoint(db_session, index_attempt_id)

View File

@@ -23,7 +23,7 @@ from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import IndexModelStatus
@@ -93,25 +93,27 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
return unfenced_attempts
class IndexingCallbackBase(IndexingHeartbeatInterface):
class IndexingCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_connector_index: RedisConnectorIndex,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.parent_pid = parent_pid
self.redis_connector: RedisConnector = redis_connector
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
self.redis_lock: RedisLock = redis_lock
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = f"{self.__class__.__name__}.__init__"
self.last_tag: str = "IndexingCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
@@ -125,8 +127,8 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
def progress(self, tag: str, amount: int) -> None:
# rkuo: this shouldn't be necessary yet because we spawn the process this runs inside
# with daemon=True. It seems likely some indexing tasks will need to spawn other processes
# eventually, which daemon=True prevents, so leave this code in until we're ready to test it.
# with daemon = True. It seems likely some indexing tasks will need to spawn other processes eventually
# so leave this code in until we're ready to test it.
# if self.parent_pid:
# # check if the parent pid is alive so we aren't running as a zombie
@@ -141,6 +143,8 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
# self.last_parent_check = now
try:
self.redis_connector.prune.set_active()
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
@@ -152,7 +156,7 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
self.last_tag = tag
except LockError:
logger.exception(
f"{self.__class__.__name__} - lock.reacquire exceptioned: "
f"IndexingCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
@@ -163,31 +167,13 @@ class IndexingCallbackBase(IndexingHeartbeatInterface):
redis_lock_dump(self.redis_lock, self.redis_client)
raise
class IndexingCallback(IndexingCallbackBase):
def __init__(
self,
parent_pid: int,
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
redis_connector_index: RedisConnectorIndex,
):
super().__init__(parent_pid, redis_connector, redis_lock, redis_client)
self.redis_connector_index: RedisConnectorIndex = redis_connector_index
def progress(self, tag: str, amount: int) -> None:
self.redis_connector_index.set_active()
self.redis_connector_index.set_connector_active()
super().progress(tag, amount)
self.redis_client.incrby(
self.redis_connector_index.generator_progress_key, amount
)
def validate_indexing_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
@@ -311,7 +297,7 @@ def validate_indexing_fence(
def validate_indexing_fences(
tenant_id: str,
tenant_id: str | None,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
@@ -332,7 +318,7 @@ def validate_indexing_fences(
if not key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
continue
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
validate_indexing_fence(
tenant_id,
key_bytes,
@@ -442,7 +428,7 @@ def try_creating_indexing_task(
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.

View File

@@ -8,7 +8,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import LLM_MODEL_UPDATE_API_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import LLMProvider
@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
trail=False,
bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
if not LLM_MODEL_UPDATE_API_URL:
raise ValueError("LLM model update API URL not configured")
@@ -75,7 +75,7 @@ def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
return None
# Then update the database with the fetched models
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
# Get the default LLM provider
default_provider = (
db_session.query(LLMProvider)

View File

@@ -26,8 +26,7 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_db_current_time
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -43,6 +42,7 @@ from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60 # 6 minutes
@@ -91,7 +91,7 @@ class Metric(BaseModel):
}
task_logger.info(json.dumps(data))
def emit(self, tenant_id: str) -> None:
def emit(self, tenant_id: str | None) -> None:
# Convert value to appropriate type based on the input value
bool_value = None
float_value = None
@@ -656,7 +656,7 @@ def build_job_id(
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
"""Collect and emit metrics about background processes.
This task runs periodically to gather metrics about:
- Queue lengths for different Celery queues
@@ -668,7 +668,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
task_logger.info("Starting background monitoring")
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
lock_monitoring: RedisLock = r.lock(
OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK,
@@ -683,7 +683,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
try:
# Get Redis client for Celery broker
redis_celery = self.app.broker_connection().channel().client # type: ignore
redis_std = get_redis_client()
redis_std = get_redis_client(tenant_id=tenant_id)
# Define metric collection functions and their dependencies
metric_functions: list[Callable[[], list[Metric]]] = [
@@ -693,7 +693,7 @@ def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
]
# Collect and log each metric
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
for metric_fn in metric_functions:
metrics = metric_fn()
for metric in metrics:
@@ -771,11 +771,12 @@ def cloud_check_alembic() -> bool | None:
if tenant_id is None:
continue
with get_session_with_shared_schema() as session:
with get_session_with_tenant(tenant_id=None) as session:
try:
result = session.execute(
text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1')
)
result_scalar: str | None = result.scalar_one_or_none()
if result_scalar is None:
raise ValueError("Alembic version should not be None.")
@@ -864,7 +865,7 @@ def cloud_monitor_celery_queues(
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
return monitor_celery_queues_helper(self)

View File

@@ -15,7 +15,7 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import PostgresAdvisoryLocks
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
@shared_task(
@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up
@@ -36,7 +36,7 @@ def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
ctx["deleted"] = 0
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
# Exit the task if we can't take the advisory lock
result = db_session.execute(
text("SELECT pg_try_advisory_lock(:id)"),

View File

@@ -21,7 +21,7 @@ from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector
from onyx.background.celery.tasks.indexing.utils import IndexingCallbackBase
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
@@ -41,7 +41,7 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.document import get_documents_for_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
@@ -55,7 +55,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@@ -63,12 +62,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
class PruneCallback(IndexingCallbackBase):
def progress(self, tag: str, amount: int) -> None:
self.redis_connector.prune.set_active()
super().progress(tag, amount)
"""Jobs / utils for kicking off pruning tasks."""
@@ -114,9 +107,9 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
@@ -134,14 +127,14 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
# but pruning only kicks off once per hour
if not r.exists(OnyxRedisSignals.BLOCK_PRUNING):
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pairs = get_connector_credential_pairs(db_session)
for cc_pair_entry in cc_pairs:
cc_pair_ids.append(cc_pair_entry.id)
for cc_pair_id in cc_pair_ids:
lock_beat.reacquire()
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
db_session=db_session,
cc_pair_id=cc_pair_id,
@@ -189,20 +182,18 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
key_str = key_bytes.decode("utf-8")
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(f"Unexpected pruning check exception: {error_msg}")
except Exception:
task_logger.exception("Unexpected exception during pruning check")
finally:
if lock_beat.owned():
lock_beat.release()
task_logger.info(f"check_for_pruning finished: tenant={tenant_id}")
return True
@@ -211,7 +202,7 @@ def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
@@ -304,19 +295,13 @@ def try_creating_prune_generator_task(
redis_connector.prune.set_fence(payload)
payload_id = payload.id
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_prune_generator_task exception: cc_pair={cc_pair.id} {error_msg}"
)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"try_creating_prune_generator_task finished: cc_pair={cc_pair.id} payload_id={payload_id}"
)
return payload_id
@@ -333,7 +318,7 @@ def connector_pruning_generator_task(
cc_pair_id: int,
connector_id: int,
credential_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -352,7 +337,7 @@ def connector_pruning_generator_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
r = get_redis_client()
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
@@ -410,7 +395,7 @@ def connector_pruning_generator_task(
return None
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair(
db_session=db_session,
connector_id=connector_id,
@@ -440,7 +425,6 @@ def connector_pruning_generator_task(
f"cc_pair={cc_pair_id} "
f"connector_source={cc_pair.connector.source}"
)
runnable_connector = instantiate_connector(
db_session,
cc_pair.connector.source,
@@ -450,11 +434,12 @@ def connector_pruning_generator_task(
)
search_settings = get_current_search_settings(db_session)
redis_connector.new_index(search_settings.id)
redis_connector_index = redis_connector.new_index(search_settings.id)
callback = PruneCallback(
callback = IndexingCallback(
0,
redis_connector,
redis_connector_index,
lock,
r,
)
@@ -521,7 +506,7 @@ def connector_pruning_generator_task(
def monitor_ccpair_pruning_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -567,7 +552,7 @@ def monitor_ccpair_pruning_taskset(
def validate_pruning_fences(
tenant_id: str,
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
@@ -615,7 +600,7 @@ def validate_pruning_fences(
def validate_pruning_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
queued_tasks: set[str],

View File

@@ -32,7 +32,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
) -> int:
return self.index.delete_single(
@@ -50,7 +50,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:

View File

@@ -1,5 +1,4 @@
import time
from enum import Enum
from http import HTTPStatus
import httpx
@@ -28,7 +27,7 @@ from onyx.db.document import mark_document_as_modified
from onyx.db.document import mark_document_as_synced
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.search_settings import get_active_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaDocumentFields
@@ -46,24 +45,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
class OnyxCeleryTaskCompletionStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SKIPPED = "skipped"
SOFT_TIME_LIMIT = "soft_time_limit"
NON_RETRYABLE_EXCEPTION = "non_retryable_exception"
RETRYABLE_EXCEPTION = "retryable_exception"
@shared_task(
name=OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
@@ -76,7 +57,7 @@ def document_by_cc_pair_cleanup_task(
document_id: str,
connector_id: int,
credential_id: int,
tenant_id: str,
tenant_id: str | None,
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connection deletion and connector pruning parent tasks."""
@@ -97,10 +78,8 @@ def document_by_cc_pair_cleanup_task(
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
chunks_affected = 0
@@ -131,9 +110,6 @@ def document_by_cc_pair_cleanup_task(
db_session=db_session,
document_ids=[document_id],
)
db_session.commit()
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
elif count > 1:
action = "update"
@@ -177,11 +153,10 @@ def document_by_cc_pair_cleanup_task(
)
mark_document_as_synced(document_id, db_session)
db_session.commit()
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
else:
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
pass
db_session.commit()
elapsed = time.monotonic() - start
task_logger.info(
@@ -193,79 +168,57 @@ def document_by_cc_pair_cleanup_task(
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
return False
except Exception as ex:
e: Exception | None = None
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(
f"document_by_cc_pair_cleanup_task exceptioned: doc={document_id}"
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
with get_session_with_current_tenant() as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
return False
task_logger.exception(f"Unexpected exception: doc={document_id}")
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"document_by_cc_pair_cleanup_task completed: status={completion_status.value} doc={document_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
self.retry(exc=e, countdown=countdown)
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
with get_session_with_tenant(tenant_id) as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
return False
task_logger.info(f"document_by_cc_pair_cleanup_task finished: doc={document_id}")
return True
@@ -297,8 +250,7 @@ def cloud_beat_task_generator(
return None
last_lock_time = time.monotonic()
tenant_ids: list[str] = []
num_processed_tenants = 0
tenant_ids: list[str] | list[None] = []
try:
tenant_ids = get_all_tenant_ids()
@@ -326,8 +278,6 @@ def cloud_beat_task_generator(
expires=expires,
ignore_result=True,
)
num_processed_tenants += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -347,7 +297,6 @@ def cloud_beat_task_generator(
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)

View File

@@ -19,7 +19,6 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -35,7 +34,7 @@ from onyx.db.document_set import fetch_document_sets
from onyx.db.document_set import fetch_document_sets_for_document
from onyx.db.document_set import get_document_set_by_id
from onyx.db.document_set import mark_document_set_as_synced
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import DocumentSet
@@ -76,7 +75,7 @@ logger = setup_logger()
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
@@ -85,8 +84,8 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
time_start = time.monotonic()
r = get_redis_client()
r_replica = get_redis_replica_client()
r = get_redis_client(tenant_id=tenant_id)
r_replica = get_redis_replica_client(tenant_id=tenant_id)
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK,
@@ -99,7 +98,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
try:
# 1/3: KICKOFF
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_stale_document_sync_tasks(
self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id
)
@@ -107,7 +106,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
# region document set scan
lock_beat.reacquire()
document_set_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
# check if any document sets are not synced
document_set_info = fetch_document_sets(
user_id=None, db_session=db_session, include_outdated=True
@@ -118,7 +117,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
for document_set_id in document_set_ids:
lock_beat.reacquire()
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_document_set_sync_tasks(
self.app, document_set_id, db_session, r, lock_beat, tenant_id
)
@@ -137,7 +136,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
pass
else:
usergroup_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
user_groups = fetch_user_groups(
db_session=db_session, only_up_to_date=False
)
@@ -147,7 +146,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
for usergroup_id in usergroup_ids:
lock_beat.reacquire()
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
try_generate_user_group_sync_tasks(
self.app, usergroup_id, db_session, r, lock_beat, tenant_id
)
@@ -168,7 +167,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
if key_str == RedisGlobalConnectorCredentialPair.FENCE_KEY:
monitor_connector_taskset(r)
elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
monitor_document_set_taskset(tenant_id, key_bytes, r, db_session)
elif key_str.startswith(RedisUserGroup.FENCE_PREFIX):
monitor_usergroup_taskset = (
@@ -178,7 +177,7 @@ def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
noop_fallback,
)
)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session)
except SoftTimeLimitExceeded:
@@ -208,7 +207,7 @@ def try_generate_stale_document_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
@@ -284,7 +283,7 @@ def try_generate_document_set_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -361,7 +360,7 @@ def try_generate_user_group_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -448,7 +447,7 @@ def monitor_connector_taskset(r: Redis) -> None:
def monitor_document_set_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -523,13 +522,13 @@ def monitor_document_set_taskset(
time_limit=LIGHT_TIME_LIMIT,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
def vespa_metadata_sync_task(
self: Task, document_id: str, tenant_id: str | None
) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
active_search_settings = get_active_search_settings(db_session)
doc_index = get_default_document_index(
search_settings=active_search_settings.primary,
@@ -541,103 +540,75 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
doc = get_document(document_id, db_session)
if not doc:
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=no_operation "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
else:
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
return False
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
except Exception as ex:
e: Exception | None = None
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(
f"vespa_metadata_sync_task exceptioned: doc={document_id}"
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"vespa_metadata_sync_task completed: status={completion_status.value} doc={document_id}"
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
return False
task_logger.exception(
f"Unexpected exception during vespa metadata sync: doc={document_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True

View File

@@ -1,7 +1,5 @@
from sqlalchemy.exc import IntegrityError
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
def emit_background_error(
@@ -11,10 +9,5 @@ def emit_background_error(
"""Currently just saves a row in the background_errors table.
In the future, could create notifications based on the severity."""
with get_session_with_current_tenant() as db_session:
try:
create_background_error(db_session, message, cc_pair_id)
except IntegrityError as e:
# Log an error if the cc_pair_id was deleted or any other exception occurs
error_message = f"Failed to create background error: {str(e)}. Original message: {message}"
create_background_error(db_session, error_message, None)
with get_session_with_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)

View File

@@ -17,9 +17,6 @@ from typing import Optional
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -57,15 +54,6 @@ def _initializer(
kwargs = {}
logger.info("Initializing spawned worker child process.")
# 1. Get tenant_id from args or fallback to default
tenant_id = POSTGRES_DEFAULT_SCHEMA
for arg in reversed(args):
if isinstance(arg, str) and arg.startswith(TENANT_ID_PREFIX):
tenant_id = arg
break
# 2. Set the tenant context before running anything
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Reset the engine in the child process
SqlEngine.reset_engine()
@@ -93,8 +81,6 @@ def _initializer(
queue.put(error_msg) # Send the exception to the parent process
sys.exit(255) # use 255 to indicate a generic exception
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def _run_in_process(

View File

@@ -15,13 +15,11 @@ from onyx.background.indexing.memory_tracer import MemoryTracer
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD
from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE
from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -30,7 +28,7 @@ from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import create_index_attempt_error
from onyx.db.index_attempt import get_index_attempt
@@ -55,7 +53,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -68,6 +65,7 @@ def _get_connector_runner(
batch_size: int,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
) -> ConnectorRunner:
"""
@@ -86,12 +84,8 @@ def _get_connector_runner(
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
tenant_id=tenant_id,
)
# validate the connector settings
if not INTEGRATION_TESTS_MODE:
runnable_connector.validate_connector_settings()
except Exception as e:
logger.exception(f"Unable to instantiate connector due to {e}")
@@ -240,7 +234,7 @@ def _check_failure_threshold(
def _run_indexing(
db_session: Session,
index_attempt_id: int,
tenant_id: str,
tenant_id: str | None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""
@@ -250,7 +244,7 @@ def _run_indexing(
"""
start_time = time.monotonic() # jsut used for logging
with get_session_with_current_tenant() as db_session_temp:
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt_start = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt_start:
raise ValueError(
@@ -376,7 +370,7 @@ def _run_indexing(
document_count = 0
chunk_count = 0
try:
with get_session_with_current_tenant() as db_session_temp:
with get_session_with_tenant(tenant_id) as db_session_temp:
index_attempt = get_index_attempt(db_session_temp, index_attempt_id)
if not index_attempt:
raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.")
@@ -387,6 +381,7 @@ def _run_indexing(
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
# don't use a checkpoint if we're explicitly indexing from
@@ -435,7 +430,7 @@ def _run_indexing(
raise ConnectorStopSignal("Connector stop signal detected")
# TODO: should we move this into the above callback instead?
with get_session_with_current_tenant() as db_session_temp:
with get_session_with_tenant(tenant_id) as db_session_temp:
# will exception if the connector/index attempt is marked as paused/failed
_check_connector_and_attempt_status(
db_session_temp, ctx, index_attempt_id
@@ -444,7 +439,7 @@ def _run_indexing(
# save record of any failures at the connector level
if failure is not None:
total_failures += 1
with get_session_with_current_tenant() as db_session_temp:
with get_session_with_tenant(tenant_id) as db_session_temp:
create_index_attempt_error(
index_attempt_id,
ctx.cc_pair_id,
@@ -508,7 +503,7 @@ def _run_indexing(
if document.id not in failed_document_ids
]
for document_id in successful_document_ids:
with get_session_with_current_tenant() as db_session_temp:
with get_session_with_tenant(tenant_id) as db_session_temp:
if document_id in doc_id_to_unresolved_errors:
logger.info(
f"Resolving IndexAttemptError for document '{document_id}'"
@@ -521,7 +516,7 @@ def _run_indexing(
# add brand new failures
if index_pipeline_result.failures:
total_failures += len(index_pipeline_result.failures)
with get_session_with_current_tenant() as db_session_temp:
with get_session_with_tenant(tenant_id) as db_session_temp:
for failure in index_pipeline_result.failures:
create_index_attempt_error(
index_attempt_id,
@@ -538,7 +533,7 @@ def _run_indexing(
)
# This new value is updated every batch, so UI can refresh per batch update
with get_session_with_current_tenant() as db_session_temp:
with get_session_with_tenant(tenant_id) as db_session_temp:
# NOTE: Postgres uses the start of the transactions when computing `NOW()`
# so we need either to commit() or to use a new session
update_docs_indexed(
@@ -560,7 +555,7 @@ def _run_indexing(
check_checkpoint_size(checkpoint)
# save latest checkpoint
with get_session_with_current_tenant() as db_session_temp:
with get_session_with_tenant(tenant_id) as db_session_temp:
save_checkpoint(
db_session=db_session_temp,
index_attempt_id=index_attempt_id,
@@ -572,29 +567,9 @@ def _run_indexing(
"Connector run exceptioned after elapsed time: "
f"{time.monotonic() - start_time} seconds"
)
if isinstance(e, ConnectorValidationError):
# On validation errors during indexing, we want to cancel the indexing attempt
# and mark the CCPair as invalid. This prevents the connector from being
# used in the future until the credentials are updated.
with get_session_with_current_tenant() as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
reason=str(e),
)
if ctx.is_primary:
update_connector_credential_pair(
db_session=db_session_temp,
connector_id=ctx.connector_id,
credential_id=ctx.credential_id,
status=ConnectorCredentialPairStatus.INVALID,
)
memory_tracer.stop()
raise e
elif isinstance(e, ConnectorStopSignal):
with get_session_with_current_tenant() as db_session_temp:
if isinstance(e, ConnectorStopSignal):
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_canceled(
index_attempt_id,
db_session_temp,
@@ -612,7 +587,7 @@ def _run_indexing(
memory_tracer.stop()
raise e
else:
with get_session_with_current_tenant() as db_session_temp:
with get_session_with_tenant(tenant_id) as db_session_temp:
mark_attempt_failed(
index_attempt_id,
db_session_temp,
@@ -634,7 +609,7 @@ def _run_indexing(
memory_tracer.stop()
elapsed_time = time.monotonic() - start_time
with get_session_with_current_tenant() as db_session_temp:
with get_session_with_tenant(tenant_id) as db_session_temp:
# resolve entity-based errors
for error in entity_based_unresolved_errors:
logger.info(f"Resolving IndexAttemptError for entity '{error.entity_id}'")
@@ -679,7 +654,7 @@ def _run_indexing(
def run_indexing_entrypoint(
index_attempt_id: int,
tenant_id: str,
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
@@ -694,12 +669,12 @@ def run_indexing_entrypoint(
TaskAttemptSingleton.set_cc_and_index_id(
index_attempt_id, connector_credential_pair_id
)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
# TODO: remove long running session entirely
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if MULTI_TENANT:
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
connector_name = attempt.connector_credential_pair.connector.name
@@ -715,7 +690,7 @@ def run_indexing_entrypoint(
f"credentials='{credential_id}'"
)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id) as db_session:
_run_indexing(db_session, index_attempt_id, tenant_id, callback)
logger.info(

View File

@@ -190,8 +190,7 @@ def create_chat_chain(
and previous_message.message_type == MessageType.ASSISTANT
and mainline_messages
):
if current_message.refined_answer_improvement:
mainline_messages[-1] = current_message
mainline_messages[-1] = current_message
else:
mainline_messages.append(current_message)

View File

@@ -142,15 +142,6 @@ class MessageResponseIDInfo(BaseModel):
reserved_assistant_message_id: int
class AgentMessageIDInfo(BaseModel):
level: int
message_id: int
class AgenticMessageResponseIDInfo(BaseModel):
agentic_message_ids: list[AgentMessageIDInfo]
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None

View File

@@ -11,8 +11,6 @@ from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallExcep
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona
from onyx.chat.models import AgenticMessageResponseIDInfo
from onyx.chat.models import AgentMessageIDInfo
from onyx.chat.models import AgentSearchPacket
from onyx.chat.models import AllCitations
from onyx.chat.models import AnswerPostInfo
@@ -145,10 +143,9 @@ from onyx.utils.long_term_log import LongTermLogger
from onyx.utils.telemetry import mt_cloud_telemetry
from onyx.utils.timing import log_function_time
from onyx.utils.timing import log_generator_function_time
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
ERROR_TYPE_CANCELLED = "cancelled"
def _translate_citations(
@@ -310,7 +307,6 @@ ChatPacket = (
| CustomToolResponse
| MessageSpecificCitations
| MessageResponseIDInfo
| AgenticMessageResponseIDInfo
| StreamStopInfo
| AgentSearchPacket
)
@@ -346,7 +342,7 @@ def stream_chat_message_objects(
3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
4. [always] Details on the final AI response message that is created
"""
tenant_id = get_current_tenant_id()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
use_existing_user_message = new_msg_req.use_existing_user_message
existing_assistant_message_id = new_msg_req.existing_assistant_message_id
@@ -635,7 +631,6 @@ def stream_chat_message_objects(
db_session=db_session,
commit=False,
reserved_message_id=reserved_message_id,
is_agentic=new_msg_req.use_agentic_search,
)
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
@@ -747,13 +742,14 @@ def stream_chat_message_objects(
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config, llm.config),
system_message=default_build_system_message(prompt_config),
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
raw_user_uploaded_files=latest_query_files or [],
single_message_history=single_message_history,
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
# LLM prompt building, response capturing, etc.
answer = Answer(
@@ -869,6 +865,7 @@ def stream_chat_message_objects(
for img in img_generation_response
if img.image_data
],
tenant_id=tenant_id,
)
info.ai_message_files.extend(
[
@@ -1018,7 +1015,7 @@ def stream_chat_message_objects(
if info.message_specific_citations
else None
),
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
error=None,
tool_call=(
ToolCall(
tool_id=tool_name_to_tool_id[info.tool_result.tool_name],
@@ -1036,7 +1033,6 @@ def stream_chat_message_objects(
next_level = 1
prev_message = gen_ai_response_message
agent_answers = answer.llm_answer_by_level()
agentic_message_ids = []
while next_level in agent_answers:
next_answer = agent_answers[next_level]
info = info_by_subq[
@@ -1057,12 +1053,7 @@ def stream_chat_message_objects(
citations=info.message_specific_citations.citation_map
if info.message_specific_citations
else None,
error=ERROR_TYPE_CANCELLED if answer.is_cancelled() else None,
refined_answer_improvement=refined_answer_improvement,
is_agentic=True,
)
agentic_message_ids.append(
AgentMessageIDInfo(level=next_level, message_id=next_answer_message.id)
)
next_level += 1
prev_message = next_answer_message
@@ -1070,9 +1061,11 @@ def stream_chat_message_objects(
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
yield AgenticMessageResponseIDInfo(agentic_message_ids=agentic_message_ids)
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield translate_db_message_to_chat_message_detail(gen_ai_response_message)
yield msg_detail_response
except Exception as e:
error_msg = str(e)
logger.exception(error_msg)

View File

@@ -12,7 +12,6 @@ from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_toke
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
@@ -20,7 +19,6 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
@@ -33,16 +31,8 @@ from onyx.tools.tool import Tool
def default_build_system_message(
prompt_config: PromptConfig,
llm_config: LLMConfig,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
# for o-series markdown generation
if (
llm_config.model_provider == OPENAI_PROVIDER_NAME
and llm_config.model_name.startswith("o")
):
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt,
prompt_config,
@@ -120,8 +110,21 @@ class AnswerPromptBuilder:
),
)
self.update_system_prompt(system_message)
self.update_user_prompt(user_message)
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
(
system_message,
check_message_tokens(system_message, self.llm_tokenizer_encode_func),
)
if system_message
else None
)
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(
user_message,
self.llm_tokenizer_encode_func,
),
)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []

View File

@@ -178,13 +178,13 @@ AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
)
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 4 # in seconds
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
@@ -204,7 +204,7 @@ AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
@@ -217,7 +217,7 @@ AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 3 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
@@ -256,7 +256,7 @@ AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
@@ -269,7 +269,7 @@ AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 3 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
@@ -282,7 +282,7 @@ AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 2 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 1 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
@@ -295,7 +295,7 @@ AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
@@ -308,7 +308,7 @@ AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
@@ -321,7 +321,7 @@ AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION

View File

@@ -6,7 +6,6 @@ from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
#####
@@ -30,9 +29,6 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
) # 1 day
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
)
#####
# Web Configs
@@ -630,8 +626,6 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"
MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"

View File

@@ -98,18 +98,9 @@ CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120
CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120
# hard timeout applied by the watchdog to the indexing connector run
# to handle hung connectors
CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds)
# soft timeout for the lock taken by the indexing connector run
# allows the lock to eventually expire if the managing code around it dies
# needs to be long enough to cover the maximum time it takes to download an object
# if we can get callbacks as object bytes download, we could lower this a lot.
# CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 15 minutes
# hard termination should always fire first if the connector is hung
CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900
CELERY_INDEXING_LOCK_TIMEOUT = 3 * 60 * 60 # 60 min
# how long a task should wait for associated fence to be ready
CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min
@@ -213,12 +204,6 @@ class AuthType(str, Enum):
CLOUD = "cloud"
class QueryHistoryType(str, Enum):
DISABLED = "disabled"
ANONYMIZED = "anonymized"
NORMAL = "normal"
# Special characters for password validation
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
@@ -348,9 +333,6 @@ class OnyxRedisSignals:
BLOCK_PRUNING = "signal:block_pruning"
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
"signal:block_validate_connector_deletion_fences"
)
class OnyxRedisConstants:

View File

@@ -7,18 +7,11 @@ from typing import Optional
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from botocore.exceptions import PartialCredentialsError
from mypy_boto3_s3 import S3Client # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import BlobType
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -247,73 +240,6 @@ class BlobStorageConnector(LoadConnector, PollConnector):
return None
def validate_connector_settings(self) -> None:
if self.s3_client is None:
raise ConnectorMissingCredentialError(
"Blob storage credentials not loaded."
)
if not self.bucket_name:
raise ConnectorValidationError(
"No bucket name was provided in connector settings."
)
try:
# We only fetch one object/page as a light-weight validation step.
# This ensures we trigger typical S3 permission checks (ListObjectsV2, etc.).
self.s3_client.list_objects_v2(
Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
)
except NoCredentialsError:
raise ConnectorMissingCredentialError(
"No valid blob storage credentials found or provided to boto3."
)
except PartialCredentialsError:
raise ConnectorMissingCredentialError(
"Partial or incomplete blob storage credentials provided to boto3."
)
except ClientError as e:
error_code = e.response["Error"].get("Code", "")
status_code = e.response["ResponseMetadata"].get("HTTPStatusCode")
# Most common S3 error cases
if error_code in [
"AccessDenied",
"InvalidAccessKeyId",
"SignatureDoesNotMatch",
]:
if status_code == 403 or error_code == "AccessDenied":
raise InsufficientPermissionsError(
f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
"Please check your bucket policy and/or IAM policy."
)
if status_code == 401 or error_code == "SignatureDoesNotMatch":
raise CredentialExpiredError(
"Provided blob storage credentials appear invalid or expired."
)
raise CredentialExpiredError(
f"Credential issue encountered ({error_code})."
)
if error_code == "NoSuchBucket" or status_code == 404:
raise ConnectorValidationError(
f"Bucket '{self.bucket_name}' does not exist or cannot be found."
)
raise ConnectorValidationError(
f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
)
except Exception as e:
# Catch-all for anything not captured by the above
# Since we are unsure of the error and it may not disable the connector,
# raise an unexpected error (does not disable connector)
raise UnexpectedError(
f"Unexpected error during blob storage settings validation: {e}"
)
if __name__ == "__main__":
credentials_dict = {

View File

@@ -5,8 +5,6 @@ import requests
class BookStackClientRequestFailedError(ConnectionError):
def __init__(self, status: int, error: str) -> None:
self.status_code = status
self.error = error
super().__init__(
"BookStack Client request failed with status {status}: {error}".format(
status=status, error=error

View File

@@ -7,11 +7,7 @@ from typing import Any
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.bookstack.client import BookStackApiClient
from onyx.connectors.bookstack.client import BookStackClientRequestFailedError
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -218,39 +214,3 @@ class BookstackConnector(LoadConnector, PollConnector):
break
else:
time.sleep(0.2)
def validate_connector_settings(self) -> None:
"""
Validate that the BookStack credentials and connector settings are correct.
Specifically checks that we can make an authenticated request to BookStack.
"""
if not self.bookstack_client:
raise ConnectorMissingCredentialError(
"BookStack credentials have not been loaded."
)
try:
# Attempt to fetch a small batch of books (arbitrary endpoint) to verify credentials
_ = self.bookstack_client.get(
"/books", params={"count": "1", "offset": "0"}
)
except BookStackClientRequestFailedError as e:
# Check for HTTP status codes
if e.status_code == 401:
raise CredentialExpiredError(
"Your BookStack credentials appear to be invalid or expired (HTTP 401)."
) from e
elif e.status_code == 403:
raise InsufficientPermissionsError(
"The configured BookStack token does not have sufficient permissions (HTTP 403)."
) from e
else:
raise ConnectorValidationError(
f"Unexpected BookStack error (status={e.status_code}): {e}"
) from e
except Exception as exc:
raise ConnectorValidationError(
f"Unexpected error while validating BookStack connector settings: {exc}"
) from exc

View File

@@ -4,8 +4,6 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote
from requests.exceptions import HTTPError
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
@@ -18,10 +16,6 @@ from onyx.connectors.confluence.utils import build_confluence_document_id
from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -403,33 +397,3 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
callback.progress("retrieve_all_slim_documents", 1)
yield doc_metadata_list
def validate_connector_settings(self) -> None:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence credentials not loaded.")
try:
spaces = self._confluence_client.get_all_spaces(limit=1)
except HTTPError as e:
status_code = e.response.status_code if e.response else None
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Confluence credentials (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Insufficient permissions to access Confluence resources (HTTP 403)."
)
raise UnexpectedError(
f"Unexpected Confluence error (status={status_code}): {e}"
)
except Exception as e:
raise UnexpectedError(
f"Unexpected error while validating Confluence settings: {e}"
)
if not spaces or not spaces.get("results"):
raise ConnectorValidationError(
"No Confluence spaces found. Either your credentials lack permissions, or "
"there truly are no spaces in this Confluence instance."
)

View File

@@ -8,12 +8,8 @@ from typing import TypeVar
from urllib.parse import quote
from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from requests import HTTPError
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -33,16 +29,6 @@ class ConfluenceRateLimitError(Exception):
pass
class ConfluenceUser(BaseModel):
user_id: str # accountId in Cloud, userKey in Server
username: str | None # Confluence Cloud doesn't give usernames
display_name: str
# Confluence Data Center doesn't give email back by default,
# have to fetch it with a different endpoint
email: str | None
type: str
def _handle_http_error(e: HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
@@ -163,7 +149,7 @@ class OnyxConfluence(Confluence):
)
def _paginate_url(
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
self, url_suffix: str, limit: int | None = None
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -238,41 +224,9 @@ class OnyxConfluence(Confluence):
raise e
# yield the results individually
results = cast(list[dict[str, Any]], next_response.get("results", []))
yield from results
yield from next_response.get("results", [])
old_url_suffix = url_suffix
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
# make sure we don't update the start by more than the amount
# of results we were able to retrieve. The Confluence API has a
# weird behavior where if you pass in a limit that is too large for
# the configured server, it will artificially limit the amount of
# results returned BUT will not apply this to the start parameter.
# This will cause us to miss results.
if url_suffix and "start" in url_suffix:
new_start = get_start_param_from_url(url_suffix)
previous_start = get_start_param_from_url(old_url_suffix)
if new_start - previous_start > len(results):
logger.warning(
f"Start was updated by more than the amount of results "
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
f"Previous Start: {previous_start}, Len Results: {len(results)}."
)
# Update the url_suffix to use the adjusted start
adjusted_start = previous_start + len(results)
url_suffix = update_param_in_path(
url_suffix, "start", str(adjusted_start)
)
# some APIs don't properly paginate, so we need to manually update the `start` param
if auto_paginate and len(results) > 0:
previous_start = get_start_param_from_url(old_url_suffix)
updated_start = previous_start + len(results)
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
url_suffix = next_response.get("_links", {}).get("next")
def paginated_cql_retrieval(
self,
@@ -321,97 +275,21 @@ class OnyxConfluence(Confluence):
self,
expand: str | None = None,
limit: int | None = None,
) -> Iterator[ConfluenceUser]:
) -> Iterator[dict[str, Any]]:
"""
The search/user endpoint can be used to fetch users.
It's a seperate endpoint from the content/search endpoint used only for users.
Otherwise it's very similar to the content/search endpoint.
"""
if self.cloud:
cql = "type=user"
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
# endpoint doesn't properly paginate, so we need to manually update the `start` param
# thus the auto_paginate flag
for user_result in self._paginate_url(url, limit, auto_paginate=True):
# Example response:
# {
# 'user': {
# 'type': 'known',
# 'accountId': '712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
# 'accountType': 'atlassian',
# 'email': 'chris@danswer.ai',
# 'publicName': 'Chris Weaver',
# 'profilePicture': {
# 'path': '/wiki/aa-avatar/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
# 'width': 48,
# 'height': 48,
# 'isDefault': False
# },
# 'displayName': 'Chris Weaver',
# 'isExternalCollaborator': False,
# '_expandable': {
# 'operations': '',
# 'personalSpace': ''
# },
# '_links': {
# 'self': 'https://danswerai.atlassian.net/wiki/rest/api/user?accountId=712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d'
# }
# },
# 'title': 'Chris Weaver',
# 'excerpt': '',
# 'url': '/people/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d',
# 'breadcrumbs': [],
# 'entityType': 'user',
# 'iconCssClass': 'aui-icon content-type-profile',
# 'lastModified': '2025-02-18T04:08:03.579Z',
# 'score': 0.0
# }
user = user_result["user"]
yield ConfluenceUser(
user_id=user["accountId"],
username=None,
display_name=user["displayName"],
email=user.get("email"),
type=user["accountType"],
)
else:
# https://developer.atlassian.com/server/confluence/rest/v900/api-group-user/#api-rest-api-user-list-get
# ^ is only available on data center deployments
# Example response:
# [
# {
# 'type': 'known',
# 'username': 'admin',
# 'userKey': '40281082950c5fe901950c61c55d0000',
# 'profilePicture': {
# 'path': '/images/icons/profilepics/default.svg',
# 'width': 48,
# 'height': 48,
# 'isDefault': True
# },
# 'displayName': 'Admin Test',
# '_links': {
# 'self': 'http://localhost:8090/rest/api/user?key=40281082950c5fe901950c61c55d0000'
# },
# '_expandable': {
# 'status': ''
# }
# }
# ]
for user in self._paginate_url("rest/api/user/list", limit):
yield ConfluenceUser(
user_id=user["userKey"],
username=user["username"],
display_name=user["displayName"],
email=None,
type=user.get("type", "user"),
)
cql = "type=user"
url = "rest/api/search/user" if self.cloud else "rest/api/search"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
yield from self._paginate_url(url, limit)
def paginated_groups_by_user_retrieval(
self,
user_id: str, # accountId in Cloud, userKey in Server
user: dict[str, Any],
limit: int | None = None,
) -> Iterator[dict[str, Any]]:
"""
@@ -419,7 +297,7 @@ class OnyxConfluence(Confluence):
It's a confluence specific endpoint that can be used to fetch groups.
"""
user_field = "accountId" if self.cloud else "key"
user_value = user_id
user_value = user["accountId"] if self.cloud else user["userKey"]
# Server uses userKey (but calls it key during the API call), Cloud uses accountId
user_query = f"{user_field}={quote(user_value)}"
@@ -545,15 +423,11 @@ def build_confluence_client(
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
try:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
except Exception as e:
raise ConnectorValidationError(str(e))
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present

View File

@@ -2,10 +2,7 @@ import io
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TYPE_CHECKING
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse
import bs4
@@ -13,13 +10,13 @@ from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.onyx_confluence import (
OnyxConfluence,
)
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
logger = setup_logger()
@@ -27,7 +24,7 @@ _USER_EMAIL_CACHE: dict[str, str | None] = {}
def get_user_email_from_username__server(
confluence_client: "OnyxConfluence", user_name: str
confluence_client: OnyxConfluence, user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
@@ -50,7 +47,7 @@ _USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
@@ -81,7 +78,7 @@ def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
def extract_text_from_confluence_html(
confluence_client: "OnyxConfluence",
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
@@ -194,7 +191,7 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
def attachment_to_content(
confluence_client: "OnyxConfluence",
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
@@ -282,32 +279,3 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def get_single_param_from_url(url: str, param: str) -> str | None:
"""Get a parameter from a url"""
parsed_url = urlparse(url)
return parse_qs(parsed_url.query).get(param, [None])[0]
def get_start_param_from_url(url: str) -> int:
"""Get the start parameter from a url"""
start_str = get_single_param_from_url(url, "start")
if start_str is None:
return 0
return int(start_str)
def update_param_in_path(path: str, param: str, value: str) -> str:
"""Update a parameter in a path. Path should look something like:
/api/rest/users?start=0&limit=10
"""
parsed_url = urlparse(path)
query_params = parse_qs(parsed_url.query)
query_params[param] = [value]
return (
path.split("?")[0]
+ "?"
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
)

View File

@@ -1,4 +1,3 @@
import re
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
@@ -25,22 +24,16 @@ def datetime_to_utc(dt: datetime) -> datetime:
def time_str_to_utc(datetime_str: str) -> datetime:
# Remove all timezone abbreviations in parentheses
datetime_str = re.sub(r"\([A-Z]+\)", "", datetime_str).strip()
# Remove any remaining parentheses and their contents
datetime_str = re.sub(r"\(.*?\)", "", datetime_str).strip()
try:
dt = parse(datetime_str)
except ValueError:
# Fix common format issues (e.g. "0000" => "+0000")
# Handle malformed timezone by attempting to fix common format issues
if "0000" in datetime_str:
datetime_str = datetime_str.replace(" 0000", " +0000")
dt = parse(datetime_str)
# Convert "0000" to "+0000" for proper timezone parsing
fixed_dt_str = datetime_str.replace(" 0000", " +0000")
dt = parse(fixed_dt_str)
else:
raise
return datetime_to_utc(dt)

View File

@@ -4,15 +4,11 @@ from typing import Any
from dropbox import Dropbox # type: ignore
from dropbox.exceptions import ApiError # type:ignore
from dropbox.exceptions import AuthError # type:ignore
from dropbox.files import FileMetadata # type:ignore
from dropbox.files import FolderMetadata # type:ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialInvalidError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -145,29 +141,6 @@ class DropboxConnector(LoadConnector, PollConnector):
return None
def validate_connector_settings(self) -> None:
if self.dropbox_client is None:
raise ConnectorMissingCredentialError("Dropbox credentials not loaded.")
try:
self.dropbox_client.files_list_folder(path="", limit=1)
except AuthError as e:
logger.exception("Failed to validate Dropbox credentials")
raise CredentialInvalidError(f"Dropbox credential is invalid: {e.error}")
except ApiError as e:
if (
e.error is not None
and "insufficient_permissions" in str(e.error).lower()
):
raise InsufficientPermissionsError(
"Your Dropbox token does not have sufficient permissions."
)
raise ConnectorValidationError(
f"Unexpected Dropbox error during validation: {e.user_message_text or e}"
)
except Exception as e:
raise Exception(f"Unexpected error during Dropbox settings validation: {e}")
if __name__ == "__main__":
import os

View File

@@ -1,49 +0,0 @@
class ValidationError(Exception):
"""General exception for validation errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class ConnectorValidationError(ValidationError):
"""General exception for connector validation errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class UnexpectedError(ValidationError):
"""Raised when an unexpected error occurs during connector validation.
Unexpected errors don't necessarily mean the credential is invalid,
but rather that there was an error during the validation process
or we encountered a currently unhandled error case.
"""
def __init__(self, message: str = "Unexpected error during connector validation"):
super().__init__(message)
class CredentialInvalidError(ConnectorValidationError):
"""Raised when a connector's credential is invalid."""
def __init__(self, message: str = "Credential is invalid"):
super().__init__(message)
class CredentialExpiredError(ConnectorValidationError):
"""Raised when a connector's credential is expired."""
def __init__(self, message: str = "Credential has expired"):
super().__init__(message)
class InsufficientPermissionsError(ConnectorValidationError):
"""Raised when the credential does not have sufficient API permissions."""
def __init__(
self, message: str = "Insufficient permissions for the requested operation"
):
super().__init__(message)

View File

@@ -3,8 +3,8 @@ from typing import Type
from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
@@ -17,7 +17,6 @@ from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
from onyx.connectors.dropbox.connector import DropboxConnector
from onyx.connectors.egnyte.connector import EgnyteConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.fireflies.connector import FirefliesConnector
from onyx.connectors.freshdesk.connector import FreshdeskConnector
@@ -53,9 +52,7 @@ from onyx.connectors.wikipedia.connector import WikipediaConnector
from onyx.connectors.xenforo.connector import XenforoConnector
from onyx.connectors.zendesk.connector import ZendeskConnector
from onyx.connectors.zulip.connector import ZulipConnector
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.models import Credential
@@ -163,9 +160,13 @@ def instantiate_connector(
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
tenant_id: str | None = None,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
if source in DocumentSourceRequiringTenantContext:
connector_specific_config["tenant_id"] = tenant_id
connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credential.credential_json)
@@ -173,51 +174,3 @@ def instantiate_connector(
backend_update_credential_json(credential, new_credentials, db_session)
return connector
def validate_ccpair_for_user(
connector_id: int,
credential_id: int,
db_session: Session,
enforce_creation: bool = True,
) -> bool:
if INTEGRATION_TESTS_MODE:
return True
# Validate the connector settings
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(
credential_id,
db_session,
)
if not connector:
raise ValueError("Connector not found")
if (
connector.source == DocumentSource.INGESTION_API
or connector.source == DocumentSource.MOCK_CONNECTOR
):
return True
if not credential:
raise ValueError("Credential not found")
try:
runnable_connector = instantiate_connector(
db_session=db_session,
source=connector.source,
input_type=connector.input_type,
connector_specific_config=connector.connector_specific_config,
credential=credential,
)
except ConnectorValidationError as e:
raise e
except Exception as e:
if enforce_creation:
raise ConnectorValidationError(str(e))
else:
return False
runnable_connector.validate_connector_settings()
return True

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -27,6 +27,8 @@ from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -163,10 +165,12 @@ class LocalFileConnector(LoadConnector):
def __init__(
self,
file_locations: list[Path | str],
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [Path(file_location) for file_location in file_locations]
self.batch_size = batch_size
self.tenant_id = tenant_id
self.pdf_pass: str | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -175,8 +179,9 @@ class LocalFileConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(self.tenant_id) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(
@@ -198,6 +203,8 @@ class LocalFileConnector(LoadConnector):
if documents:
yield documents
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if __name__ == "__main__":
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])

View File

@@ -187,12 +187,12 @@ class FirefliesConnector(PollConnector, LoadConnector):
return self._process_transcripts()
def poll_source(
self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
self, start_unixtime: SecondsSinceUnixEpoch, end_unixtime: SecondsSinceUnixEpoch
) -> GenerateDocumentsOutput:
start_datetime = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)
end_datetime = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
start_datetime = datetime.fromtimestamp(
start_unixtime, tz=timezone.utc
).strftime("%Y-%m-%dT%H:%M:%S.000Z")
end_datetime = datetime.fromtimestamp(end_unixtime, tz=timezone.utc).strftime(
"%Y-%m-%dT%H:%M:%S.000Z"
)

View File

@@ -229,20 +229,16 @@ class GitbookConnector(LoadConnector, PollConnector):
try:
content = self.client.get(f"/spaces/{self.space_id}/content")
pages: list[dict[str, Any]] = content.get("pages", [])
pages = content.get("pages", [])
current_batch: list[Document] = []
for page in pages:
updated_at = datetime.fromisoformat(page["updatedAt"])
while pages:
page = pages.pop(0)
updated_at_raw = page.get("updatedAt")
if updated_at_raw is None:
# if updatedAt is not present, that means the page has never been edited
continue
updated_at = datetime.fromisoformat(updated_at_raw)
if start and updated_at < start:
continue
if current_batch:
yield current_batch
return
if end and updated_at > end:
continue
@@ -254,8 +250,6 @@ class GitbookConnector(LoadConnector, PollConnector):
yield current_batch
current_batch = []
pages.extend(page.get("pages", []))
if current_batch:
yield current_batch

View File

@@ -9,7 +9,6 @@ from typing import cast
from github import Github
from github import RateLimitExceededException
from github import Repository
from github.GithubException import GithubException
from github.Issue import Issue
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
@@ -17,10 +16,6 @@ from github.PullRequest import PullRequest
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -31,6 +26,7 @@ from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -124,7 +120,7 @@ class GithubConnector(LoadConnector, PollConnector):
def __init__(
self,
repo_owner: str,
repo_name: str | None = None,
repo_name: str,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
@@ -162,81 +158,53 @@ class GithubConnector(LoadConnector, PollConnector):
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
def _get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
# Try to get organization first
try:
org = github_client.get_organization(self.repo_owner)
return list(org.get_repos())
except GithubException:
# If not an org, try as a user
user = github_client.get_user(self.repo_owner)
return list(user.get_repos())
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_all_repos(github_client, attempt_num + 1)
def _fetch_from_github(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repos = (
[self._get_github_repo(self.github_client)]
if self.repo_name
else self._get_all_repos(self.github_client)
)
repo = self._get_github_repo(self.github_client)
for repo in repos:
if self.include_prs:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
if self.include_prs:
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
break
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
return
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
if self.include_issues:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
if self.include_issues:
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
break
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
return
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_github()
@@ -258,63 +226,6 @@ class GithubConnector(LoadConnector, PollConnector):
return self._fetch_from_github(adjusted_start_datetime, end_datetime)
def validate_connector_settings(self) -> None:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
if not self.repo_owner:
raise ConnectorValidationError(
"Invalid connector settings: 'repo_owner' must be provided."
)
try:
if self.repo_name:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
org = self.github_client.get_organization(self.repo_owner)
org.get_repos().totalCount # Just check if we can access repos
except GithubException:
# If not an org, try as a user
user = self.github_client.get_user(self.repo_owner)
user.get_repos().totalCount # Just check if we can access repos
except RateLimitExceededException:
raise UnexpectedError(
"Validation failed due to GitHub rate-limits being exceeded. Please try again later."
)
except GithubException as e:
if e.status == 401:
raise CredentialExpiredError(
"GitHub credential appears to be invalid or expired (HTTP 401)."
)
elif e.status == 403:
raise InsufficientPermissionsError(
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
if self.repo_name:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"
)
except Exception as exc:
raise Exception(
f"Unexpected error during GitHub settings validation: {exc}"
)
if __name__ == "__main__":
import os

View File

@@ -297,7 +297,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
full_threads = execute_paginated_retrieval(
retrieval_function=gmail_service.users().threads().get,
@@ -305,7 +304,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
continue_on_404_or_403=True,
)
# full_threads is an iterator containing a single thread
# so we need to convert it to a list and grab the first element
@@ -337,7 +335,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
doc_batch.append(
SlimDocument(

View File

@@ -13,9 +13,6 @@ from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.google_drive.doc_conversion import build_slim_document
from onyx.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
@@ -45,7 +42,6 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -141,7 +137,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
raise ConnectorValidationError(
raise ValueError(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
@@ -155,7 +151,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
and not my_drive_emails
and not shared_drive_urls
):
raise ConnectorValidationError(
raise ValueError(
"Nothing to index. Please specify at least one of the following: "
"include_shared_drives, include_my_drives, include_files_shared_with_me, "
"shared_folder_urls, or my_drive_emails"
@@ -224,14 +220,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
return self._creds
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
try:
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
except KeyError:
raise ValueError(
"Primary admin email missing, "
"should not call this property "
"before calling load_credentials"
)
self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
self._creds, new_creds_dict = get_google_creds(
credentials=credentials,
@@ -613,50 +602,3 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def validate_connector_settings(self) -> None:
if self._creds is None:
raise ConnectorMissingCredentialError(
"Google Drive credentials not loaded."
)
if self._primary_admin_email is None:
raise ConnectorValidationError(
"Primary admin email not found in credentials. "
"Ensure DB_CREDENTIALS_PRIMARY_ADMIN_KEY is set."
)
try:
drive_service = get_drive_service(self._creds, self._primary_admin_email)
drive_service.files().list(pageSize=1, fields="files(id)").execute()
if isinstance(self._creds, ServiceAccountCredentials):
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
status_code = e.resp.status if e.resp else None
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Google Drive credentials (401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Google Drive app lacks required permissions (403). "
"Please ensure the necessary scopes are granted and Drive "
"apps are enabled."
)
else:
raise ConnectorValidationError(
f"Unexpected Google Drive error (status={status_code}): {e}"
)
except Exception as e:
# Check for scope-related hints from the error message
if MISSING_SCOPES_ERROR_STR in str(e):
raise InsufficientPermissionsError(
"Google Drive credentials are missing required scopes. "
f"{ONYX_SCOPE_INSTRUCTIONS}"
)
raise ConnectorValidationError(
f"Unexpected error during Google Drive validation: {e}"
)

View File

@@ -1,9 +1,7 @@
import io
from datetime import datetime
from datetime import timezone
from tempfile import NamedTemporaryFile
import openpyxl # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
@@ -45,15 +43,12 @@ def _extract_sections_basic(
) -> list[Section]:
mime_type = file["mimeType"]
link = file["webViewLink"]
supported_file_types = set(item.value for item in GDriveMimeType)
if mime_type not in supported_file_types:
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title, finding this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
# ---------------------------
# Google Sheets extraction
if mime_type == GDriveMimeType.SPREADSHEET.value:
try:
sheets_service = build(
@@ -114,53 +109,7 @@ def _extract_sections_basic(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
# ---------------------------
# Microsoft Excel (.xlsx or .xls) extraction branch
elif mime_type in [
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
]:
try:
response = service.files().get_media(fileId=file["id"]).execute()
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
tmp.write(response)
tmp_path = tmp.name
section_separator = "\n\n"
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
# Work similarly to the xlsx_to_text function used for file connector
# but returns Sections instead of a string
sections = [
Section(
link=link,
text=(
f"Sheet: {sheet.title}\n\n"
+ section_separator.join(
",".join(map(str, row))
for row in sheet.iter_rows(
min_row=1, values_only=True
)
if row
)
),
)
for sheet in workbook.worksheets
]
return sections
except Exception as e:
logger.warning(
f"Error extracting data from Excel file '{file['name']}': {e}"
)
return [
Section(link=link, text="Error extracting data from Excel file")
]
# ---------------------------
# Export for Google Docs, PPT, and fallback for spreadsheets
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
@@ -179,8 +128,6 @@ def _extract_sections_basic(
)
return [Section(link=link, text=text)]
# ---------------------------
# Plain text and Markdown files
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
@@ -194,8 +141,6 @@ def _extract_sections_basic(
.decode("utf-8"),
)
]
# ---------------------------
# Word, PowerPoint, PDF files
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
@@ -225,11 +170,7 @@ def _extract_sections_basic(
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
]
# Catch-all case, should not happen since there should be specific handling
# for each of the supported file types
error_message = f"Unsupported file type: {mime_type}"
logger.error(error_message)
raise ValueError(error_message)
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
except Exception:
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]

View File

@@ -5,10 +5,6 @@ from typing import Any
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
SPREADSHEET_OPEN_FORMAT = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"

View File

@@ -87,18 +87,16 @@ class HubSpotConnector(LoadConnector, PollConnector):
contact = api_client.crm.contacts.basic_api.get_by_id(
contact_id=contact.id
)
email = contact.properties.get("email")
if email is not None:
associated_emails.append(email)
associated_emails.append(contact.properties["email"])
if notes:
for note in notes.results:
note = api_client.crm.objects.notes.basic_api.get_by_id(
note_id=note.id, properties=["content", "hs_body_preview"]
)
preview = note.properties.get("hs_body_preview")
if preview is not None:
associated_notes.append(preview)
if note.properties["hs_body_preview"] is None:
continue
associated_notes.append(note.properties["hs_body_preview"])
associated_emails_str = " ,".join(associated_emails)
associated_notes_str = " ".join(associated_notes)

View File

@@ -12,6 +12,7 @@ from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
SecondsSinceUnixEpoch = float
GenerateDocumentsOutput = Iterator[list[Document]]
@@ -44,14 +45,6 @@ class BaseConnector(abc.ABC):
raise RuntimeError(custom_parser_req_msg)
return metadata_lines
def validate_connector_settings(self) -> None:
"""
Override this if your connector needs to validate credentials or settings.
Raise an exception if invalid, otherwise do nothing.
Default is a no-op (always successful).
"""
# Large set update or reindex, generally pulling a complete state or from a savestate file
class LoadConnector(BaseConnector):

View File

@@ -7,7 +7,6 @@ from datetime import timezone
from typing import Any
from typing import Optional
import requests
from retry import retry
from onyx.configs.app_configs import INDEX_BATCH_SIZE
@@ -16,15 +15,10 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rl_requests,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
@@ -622,64 +616,6 @@ class NotionConnector(LoadConnector, PollConnector):
else:
break
def validate_connector_settings(self) -> None:
if not self.headers.get("Authorization"):
raise ConnectorMissingCredentialError("Notion credentials not loaded.")
try:
# We'll do a minimal search call (page_size=1) to confirm accessibility
if self.root_page_id:
# If root_page_id is set, fetch the specific page
res = rl_requests.get(
f"https://api.notion.com/v1/pages/{self.root_page_id}",
headers=self.headers,
timeout=_NOTION_CALL_TIMEOUT,
)
else:
# If root_page_id is not set, perform a minimal search
test_query = {
"filter": {"property": "object", "value": "page"},
"page_size": 1,
}
res = rl_requests.post(
"https://api.notion.com/v1/search",
headers=self.headers,
json=test_query,
timeout=_NOTION_CALL_TIMEOUT,
)
res.raise_for_status()
except requests.exceptions.HTTPError as http_err:
status_code = http_err.response.status_code if http_err.response else None
if status_code == 401:
raise CredentialExpiredError(
"Notion credential appears to be invalid or expired (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Notion token does not have sufficient permissions (HTTP 403)."
)
elif status_code == 404:
# Typically means resource not found or not shared. Could be root_page_id is invalid.
raise ConnectorValidationError(
"Notion resource not found or not shared with the integration (HTTP 404)."
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Notion rate-limits being exceeded (HTTP 429). "
"Please try again later."
)
else:
raise UnexpectedError(
f"Unexpected Notion HTTP error (status={status_code}): {http_err}"
) from http_err
except Exception as exc:
raise UnexpectedError(
f"Unexpected error during Notion settings validation: {exc}"
)
if __name__ == "__main__":
import os

View File

@@ -12,9 +12,6 @@ from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -29,6 +26,7 @@ from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
from onyx.connectors.onyx_jira.utils import build_jira_client
from onyx.connectors.onyx_jira.utils import build_jira_url
from onyx.connectors.onyx_jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
from onyx.connectors.onyx_jira.utils import get_comment_strs
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -159,8 +157,7 @@ def fetch_jira_issues_batch(
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
jira_base_url: str,
project_key: str | None = None,
jira_project_url: str,
comment_email_blacklist: list[str] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
# if a ticket has one of the labels specified in this list, we will just
@@ -169,12 +166,11 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
) -> None:
self.batch_size = batch_size
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
self.jira_project = project_key
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
self._jira_client: JIRA | None = None
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
@property
def comment_email_blacklist(self) -> tuple:
@@ -189,9 +185,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
@property
def quoted_jira_project(self) -> str:
# Quote the project name to handle reserved words
if not self.jira_project:
return ""
return f'"{self.jira_project}"'
return f'"{self._jira_project}"'
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._jira_client = build_jira_client(
@@ -200,14 +194,8 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _get_jql_query(self) -> str:
"""Get the JQL query based on whether a specific project is set"""
if self.jira_project:
return f"project = {self.quoted_jira_project}"
return "" # Empty string means all accessible projects
def load_from_state(self) -> GenerateDocumentsOutput:
jql = self._get_jql_query()
jql = f"project = {self.quoted_jira_project}"
document_batch = []
for doc in fetch_jira_issues_batch(
@@ -234,10 +222,11 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
"%Y-%m-%d %H:%M"
)
base_jql = self._get_jql_query()
jql = (
f"{base_jql} AND " if base_jql else ""
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
f"project = {self.quoted_jira_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
document_batch = []
for doc in fetch_jira_issues_batch(
@@ -260,7 +249,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
jql = self._get_jql_query()
jql = f"project = {self.quoted_jira_project}"
slim_doc_batch = []
for issue in _paginate_jql_search(
@@ -283,67 +272,13 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
yield slim_doc_batch
def validate_connector_settings(self) -> None:
if self._jira_client is None:
raise ConnectorMissingCredentialError("Jira")
# If a specific project is set, validate it exists
if self.jira_project:
try:
self.jira_client.project(self.jira_project)
except Exception as e:
status_code = getattr(e, "status_code", None)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
)
elif status_code == 404:
raise ConnectorValidationError(
f"Jira project not found with key: {self.jira_project}"
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
else:
# If no project specified, validate we can access the Jira API
try:
# Try to list projects to validate access
self.jira_client.projects()
except Exception as e:
status_code = getattr(e, "status_code", None)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions to list projects (HTTP 403)."
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
if __name__ == "__main__":
import os
connector = JiraConnector(
jira_base_url=os.environ["JIRA_BASE_URL"],
project_key=os.environ.get("JIRA_PROJECT_KEY"),
comment_email_blacklist=[],
os.environ["JIRA_PROJECT_URL"], comment_email_blacklist=[]
)
connector.load_credentials(
{
"jira_user_email": os.environ["JIRA_USER_EMAIL"],

View File

@@ -18,10 +18,6 @@ from slack_sdk.errors import SlackApiError
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
@@ -86,14 +82,14 @@ def get_channels(
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace."""
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
# Try fetching both public and private channels first:
# try getting private channels as well at first
try:
channels = _collect_paginated_channels(
client=client,
@@ -101,19 +97,19 @@ def get_channels(
channel_types=channel_types,
)
except SlackApiError as e:
logger.info(
f"Unable to fetch private channels due to: {e}. Trying again without private channels."
)
logger.info(f"Unable to fetch private channels due to - {e}")
logger.info("trying again without private channels")
if get_public:
channel_types = ["public_channel"]
else:
logger.warning("No channels to fetch.")
logger.warning("No channels to fetch")
return []
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
return channels
@@ -670,86 +666,6 @@ class SlackConnector(SlimConnector, CheckpointConnector):
)
return checkpoint
def validate_connector_settings(self) -> None:
"""
1. Verify the bot token is valid for the workspace (via auth_test).
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists.
"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
try:
# 1) Validate connection to workspace
auth_response = self.client.auth_test()
if not auth_response.get("ok", False):
error_msg = auth_response.get(
"error", "Unknown error from Slack auth_test"
)
raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")
# 2) Minimal test to confirm listing channels works
test_resp = self.client.conversations_list(
limit=1, types=["public_channel"]
)
if not test_resp.get("ok", False):
error_msg = test_resp.get("error", "Unknown error from Slack")
if error_msg == "invalid_auth":
raise ConnectorValidationError(
f"Invalid Slack bot token ({error_msg})."
)
elif error_msg == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({error_msg})."
)
raise UnexpectedError(f"Slack API returned a failure: {error_msg}")
# 3) If channels are specified, verify each is accessible
if self.channels:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,
get_public=True,
get_private=True,
)
# For quick lookups by name or ID, build a map:
accessible_channel_names = {ch["name"] for ch in accessible_channels}
accessible_channel_ids = {ch["id"] for ch in accessible_channels}
for user_channel in self.channels:
if (
user_channel not in accessible_channel_names
and user_channel not in accessible_channel_ids
):
raise ConnectorValidationError(
f"Channel '{user_channel}' not found or inaccessible in this workspace."
)
except SlackApiError as e:
slack_error = e.response.get("error", "")
if slack_error == "missing_scope":
raise InsufficientPermissionsError(
"Slack bot token lacks the necessary scope to list/access channels. "
"Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
)
elif slack_error == "invalid_auth":
raise CredentialExpiredError(
f"Invalid Slack bot token ({slack_error})."
)
elif slack_error == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
)
raise UnexpectedError(
f"Unexpected Slack error '{slack_error}' during settings validation."
)
except ConnectorValidationError as e:
raise e
except Exception as e:
raise UnexpectedError(
f"Unexpected error during Slack settings validation: {e}"
)
if __name__ == "__main__":
import os

View File

@@ -5,7 +5,6 @@ from typing import Any
import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.runtime.client_request_exception import ClientRequestException # type: ignore
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.chats.messages.message import ChatMessage # type: ignore
from office365.teams.team import Team # type: ignore
@@ -13,10 +12,6 @@ from office365.teams.team import Team # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -284,50 +279,6 @@ class TeamsConnector(LoadConnector, PollConnector):
end_datetime = datetime.fromtimestamp(end, timezone.utc)
return self._fetch_from_teams(start=start_datetime, end=end_datetime)
def validate_connector_settings(self) -> None:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams credentials not loaded.")
try:
# Minimal call to confirm we can retrieve Teams
found_teams = self._get_all_teams()
except ClientRequestException as e:
status_code = e.response.status_code
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
)
raise UnexpectedError(f"Unexpected error retrieving teams: {e}")
except Exception as e:
error_str = str(e).lower()
if (
"unauthorized" in error_str
or "401" in error_str
or "invalid_grant" in error_str
):
raise CredentialExpiredError(
"Invalid or expired Microsoft Teams credentials."
)
elif "forbidden" in error_str or "403" in error_str:
raise InsufficientPermissionsError(
"App lacks required permissions to read from Microsoft Teams."
)
raise ConnectorValidationError(
f"Unexpected error during Teams validation: {e}"
)
if not found_teams:
raise ConnectorValidationError(
"No Teams found for the given credentials. "
"Either there are no Teams in this tenant, or your app does not have permission to view them."
)
if __name__ == "__main__":
connector = TeamsConnector(teams=os.environ["TEAMS"].split(","))

View File

@@ -25,10 +25,6 @@ from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_TOKEN_URL
from onyx.configs.app_configs import WEB_CONNECTOR_VALIDATE_URLS
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
@@ -41,8 +37,6 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
# Given a base site, index everything under that path
@@ -176,35 +170,26 @@ def start_playwright() -> Tuple[Playwright, BrowserContext]:
def extract_urls_from_sitemap(sitemap_url: str) -> list[str]:
try:
response = requests.get(sitemap_url)
response.raise_for_status()
response = requests.get(sitemap_url)
response.raise_for_status()
soup = BeautifulSoup(response.content, "html.parser")
urls = [
_ensure_absolute_url(sitemap_url, loc_tag.text)
for loc_tag in soup.find_all("loc")
]
soup = BeautifulSoup(response.content, "html.parser")
urls = [
_ensure_absolute_url(sitemap_url, loc_tag.text)
for loc_tag in soup.find_all("loc")
]
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
# the given url doesn't look like a sitemap, let's try to find one
urls = list_pages_for_site(sitemap_url)
if len(urls) == 0 and len(soup.find_all("urlset")) == 0:
# the given url doesn't look like a sitemap, let's try to find one
urls = list_pages_for_site(sitemap_url)
if len(urls) == 0:
raise ValueError(
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
)
return urls
except requests.RequestException as e:
raise RuntimeError(f"Failed to fetch sitemap from {sitemap_url}: {e}")
except ValueError as e:
raise RuntimeError(f"Error processing sitemap {sitemap_url}: {e}")
except Exception as e:
raise RuntimeError(
f"Unexpected error while processing sitemap {sitemap_url}: {e}"
if len(urls) == 0:
raise ValueError(
f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead."
)
return urls
def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str:
if not urlparse(maybe_relative_url).netloc:
@@ -240,14 +225,10 @@ class WebConnector(LoadConnector):
web_connector_type: str = WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value,
mintlify_cleanup: bool = True, # Mostly ok to apply to other websites as well
batch_size: int = INDEX_BATCH_SIZE,
scroll_before_scraping: bool = False,
**kwargs: Any,
) -> None:
self.mintlify_cleanup = mintlify_cleanup
self.batch_size = batch_size
self.recursive = False
self.scroll_before_scraping = scroll_before_scraping
self.web_connector_type = web_connector_type
if web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value:
self.recursive = True
@@ -363,18 +344,6 @@ class WebConnector(LoadConnector):
continue
visited_links.add(current_url)
if self.scroll_before_scraping:
scroll_attempts = 0
previous_height = page.evaluate("document.body.scrollHeight")
while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS:
page.evaluate("window.scrollTo(0, document.body.scrollHeight)")
page.wait_for_load_state("networkidle", timeout=30000)
new_height = page.evaluate("document.body.scrollHeight")
if new_height == previous_height:
break # Stop scrolling when no more content is loaded
previous_height = new_height
scroll_attempts += 1
content = page.content()
soup = BeautifulSoup(content, "html.parser")
@@ -433,56 +402,6 @@ class WebConnector(LoadConnector):
raise RuntimeError(last_error)
raise RuntimeError("No valid pages found.")
def validate_connector_settings(self) -> None:
# Make sure we have at least one valid URL to check
if not self.to_visit_list:
raise ConnectorValidationError(
"No URL configured. Please provide at least one valid URL."
)
if (
self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SITEMAP.value
or self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value
):
return None
# We'll just test the first URL for connectivity and correctness
test_url = self.to_visit_list[0]
# Check that the URL is allowed and well-formed
try:
protected_url_check(test_url)
except ValueError as e:
raise ConnectorValidationError(
f"Protected URL check failed for '{test_url}': {e}"
)
except ConnectionError as e:
# Typically DNS or other network issues
raise ConnectorValidationError(str(e))
# Make a quick request to see if we get a valid response
try:
check_internet_connection(test_url)
except Exception as e:
err_str = str(e)
if "401" in err_str:
raise CredentialExpiredError(
f"Unauthorized access to '{test_url}': {e}"
)
elif "403" in err_str:
raise InsufficientPermissionsError(
f"Forbidden access to '{test_url}': {e}"
)
elif "404" in err_str:
raise ConnectorValidationError(f"Page not found for '{test_url}': {e}")
elif "Max retries exceeded" in err_str and "NameResolutionError" in err_str:
raise ConnectorValidationError(
f"Unable to resolve hostname for '{test_url}'. Please check the URL and your internet connection."
)
else:
# Could be a 5xx or another error, treat as unexpected
raise UnexpectedError(f"Unexpected error validating '{test_url}': {e}")
if __name__ == "__main__":
connector = WebConnector("https://docs.onyx.app/")

View File

@@ -23,6 +23,7 @@ from onyx.context.search.preprocessing.access_filters import (
from onyx.context.search.retrieval.search_runner import (
remove_stop_words_and_punctuation,
)
from onyx.db.engine import CURRENT_TENANT_ID_CONTEXTVAR
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
from onyx.llm.interfaces import LLM
@@ -34,7 +35,6 @@ from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -166,7 +166,7 @@ def retrieval_preprocessing(
time_cutoff=time_filter or predicted_time_cutoff,
tags=preset_filters.tags, # Tags are never auto-extracted
access_control_list=user_acl_filters,
tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
tenant_id=CURRENT_TENANT_ID_CONTEXTVAR.get() if MULTI_TENANT else None,
)
llm_evaluation_type = LLMEvaluationType.BASIC

View File

@@ -16,7 +16,8 @@ from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.server.api_key.models import APIKeyArgs
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
def get_api_key_email_pattern() -> str:
@@ -70,9 +71,9 @@ def insert_api_key(
std_password_helper = PasswordHelper()
# Get tenant_id from context var (will be default schema for single tenant)
tenant_id = get_current_tenant_id()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
api_key = generate_api_key(tenant_id)
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
api_key_user_id = uuid.uuid4()
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER

View File

@@ -629,7 +629,6 @@ def create_new_chat_message(
reserved_message_id: int | None = None,
overridden_model: str | None = None,
refined_answer_improvement: bool | None = None,
is_agentic: bool = False,
) -> ChatMessage:
if reserved_message_id is not None:
# Edit existing message
@@ -651,7 +650,7 @@ def create_new_chat_message(
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
existing_message.refined_answer_improvement = refined_answer_improvement
existing_message.is_agentic = is_agentic
new_chat_message = existing_message
else:
# Create new message
@@ -671,7 +670,6 @@ def create_new_chat_message(
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
refined_answer_improvement=refined_answer_improvement,
is_agentic=is_agentic,
)
db_session.add(new_chat_message)
@@ -962,7 +960,6 @@ def translate_db_message_to_chat_message_detail(
chat_message.sub_questions
),
refined_answer_improvement=chat_message.refined_answer_improvement,
error=chat_message.error,
)
return chat_msg_detail

View File

@@ -1,152 +0,0 @@
from typing import List
from typing import Optional
from typing import Tuple
from uuid import UUID
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import literal
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import union_all
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
def search_chat_sessions(
user_id: UUID | None,
db_session: Session,
query: Optional[str] = None,
page: int = 1,
page_size: int = 10,
include_deleted: bool = False,
include_onyxbot_flows: bool = False,
) -> Tuple[List[ChatSession], bool]:
"""
Search for chat sessions based on the provided query.
If no query is provided, returns recent chat sessions.
Returns a tuple of (chat_sessions, has_more)
"""
offset = (page - 1) * page_size
# If no search query, we use standard SQLAlchemy pagination
if not query or not query.strip():
stmt = select(ChatSession)
if user_id:
stmt = stmt.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
stmt = stmt.where(ChatSession.deleted.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_created))
# Apply pagination
stmt = stmt.offset(offset).limit(page_size + 1)
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
chat_sessions = result.scalars().all()
has_more = len(chat_sessions) > page_size
if has_more:
chat_sessions = chat_sessions[:page_size]
return list(chat_sessions), has_more
words = query.lower().strip().split()
# Message mach subquery
message_matches = []
for word in words:
word_like = f"%{word}%"
message_match: Select = (
select(ChatMessage.chat_session_id, literal(1.0).label("search_rank"))
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
.where(func.lower(ChatMessage.message).like(word_like))
)
if user_id:
message_match = message_match.where(ChatSession.user_id == user_id)
message_matches.append(message_match)
if message_matches:
message_matches_query = union_all(*message_matches).alias("message_matches")
else:
return [], False
# Description matches
description_match: Select = select(
ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank")
).where(func.lower(ChatSession.description).like(f"%{query.lower()}%"))
if user_id:
description_match = description_match.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
description_match = description_match.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
description_match = description_match.where(ChatSession.deleted.is_(False))
# Combine all match sources
combined_matches = union_all(
message_matches_query.select(), description_match
).alias("combined_matches")
# Use CTE to group and get max rank
session_ranks = (
select(
combined_matches.c.chat_session_id,
func.max(combined_matches.c.search_rank).label("rank"),
)
.group_by(combined_matches.c.chat_session_id)
.alias("session_ranks")
)
# Get ranked sessions with pagination
ranked_query = (
db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank)
.order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id)
.offset(offset)
.limit(page_size + 1)
)
result = ranked_query.all()
# Extract session IDs and ranks
session_ids_with_ranks = {row.chat_session_id: row.rank for row in result}
session_ids = list(session_ids_with_ranks.keys())
if not session_ids:
return [], False
# Now, let's query the actual ChatSession objects using the IDs
stmt = select(ChatSession).where(ChatSession.id.in_(session_ids))
if user_id:
stmt = stmt.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
stmt = stmt.where(ChatSession.deleted.is_(False))
# Full objects with eager loading
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
chat_sessions = result.scalars().all()
# Sort based on above ranking
chat_sessions = sorted(
chat_sessions,
key=lambda session: (
-session_ids_with_ranks.get(session.id, 0), # Rank (higher first)
session.time_created.timestamp() * -1, # Then by time (newest first)
),
)
has_more = len(chat_sessions) > page_size
if has_more:
chat_sessions = chat_sessions[:page_size]
return chat_sessions, has_more

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import TypeVarTuple
from fastapi import HTTPException
from sqlalchemy import delete
@@ -9,18 +8,15 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
@@ -35,12 +31,10 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
R = TypeVarTuple("R")
def _add_user_filters(
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
) -> Select[tuple[*R]]:
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -104,52 +98,17 @@ def get_connector_credential_pairs_for_user(
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
if eager_load_user:
assert (
eager_load_credential
), "eager_load_credential must be True if eager_load_user is True"
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
if eager_load_user:
load_opts = load_opts.joinedload(Credential.user)
stmt = stmt.options(load_opts)
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).unique().all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_connector_credential_pairs_for_user_parallel(
user: User | None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_connector_credential_pairs_for_user(
db_session,
user,
get_editable,
ids,
eager_load_connector,
eager_load_credential,
eager_load_user,
)
return list(db_session.scalars(stmt).all())
def get_connector_credential_pairs(
@@ -192,16 +151,6 @@ def get_cc_pair_groups_for_ids(
return list(db_session.scalars(stmt).all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_cc_pair_groups_for_ids_parallel(
cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)
def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,
@@ -245,14 +194,9 @@ def get_connector_credential_pair_from_id_for_user(
def get_connector_credential_pair_from_id(
db_session: Session,
cc_pair_id: int,
eager_load_credential: bool = False,
) -> ConnectorCredentialPair | None:
stmt = select(ConnectorCredentialPair).distinct()
stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
if eager_load_credential:
stmt = stmt.options(joinedload(ConnectorCredentialPair.credential))
result = db_session.execute(stmt)
return result.scalar_one_or_none()
@@ -452,8 +396,8 @@ def add_credential_to_connector(
# If we are in the seeding flow, we shouldn't need to check if the credential belongs to the user
if seeding_flow:
credential = fetch_credential_by_id(
credential_id=credential_id,
db_session=db_session,
credential_id=credential_id,
)
else:
credential = fetch_credential_by_id_for_user(

View File

@@ -14,7 +14,6 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY,
)
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
@@ -169,8 +168,8 @@ def fetch_credential_by_id_for_user(
def fetch_credential_by_id(
credential_id: int,
db_session: Session,
credential_id: int,
) -> Credential | None:
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
@@ -246,10 +245,6 @@ def swap_credentials_connector(
existing_pair.credential_id = new_credential_id
existing_pair.credential = new_credential
# Update ccpair status if it's in INVALID state
if existing_pair.status == ConnectorCredentialPairStatus.INVALID:
existing_pair.status = ConnectorCredentialPairStatus.ACTIVE
# Commit the changes
db_session.commit()
@@ -422,8 +417,8 @@ def create_initial_public_credential(db_session: Session) -> None:
"There must exist an empty public credential for data connectors that do not require additional Auth."
)
first_credential = fetch_credential_by_id(
credential_id=PUBLIC_CREDENTIAL_ID,
db_session=db_session,
credential_id=PUBLIC_CREDENTIAL_ID,
)
if first_credential is not None:

View File

@@ -24,7 +24,6 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
@@ -61,8 +60,9 @@ def count_documents_by_needs_sync(session: Session) -> int:
This function executes the query and returns the count of
documents matching the criteria."""
return (
session.query(DbDocument.id)
count = (
session.query(func.count(DbDocument.id.distinct()))
.select_from(DbDocument)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
@@ -73,53 +73,63 @@ def count_documents_by_needs_sync(session: Session) -> int:
DbDocument.last_synced.is_(None),
)
)
.count()
.scalar()
)
return count
def construct_document_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
return (
select(DbDocument)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
or_(
DbDocument.last_modified > DbDocument.last_synced,
DbDocument.last_synced.is_(None),
),
)
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
select(DbDocument)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
)
.distinct()
)
return stmt
def construct_document_id_select_for_connector_credential_pair_by_needs_sync(
connector_id: int, credential_id: int
) -> Select:
return (
select(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DbDocument.id == DocumentByConnectorCredentialPair.id,
)
.where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
or_(
DbDocument.last_modified > DbDocument.last_synced,
DbDocument.last_synced.is_(None),
),
)
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
stmt = (
select(DbDocument.id)
.where(
DbDocument.id.in_(initial_doc_ids_stmt),
or_(
DbDocument.last_modified
> DbDocument.last_synced, # last_modified is newer than last_synced
DbDocument.last_synced.is_(None), # never synced
),
)
.distinct()
)
return stmt
def get_all_documents_needing_vespa_sync_for_cc_pair(
db_session: Session, cc_pair_id: int
@@ -230,12 +240,12 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier]
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
# Prepare a list of (connector_id, credential_id) tuples
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pair_identifiers]
stmt = (
select(
@@ -261,16 +271,6 @@ def get_document_counts_for_cc_pairs(
return db_session.execute(stmt).all() # type: ignore
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_document_counts_for_cc_pairs_parallel(
cc_pairs: list[ConnectorCredentialPairIdentifier],
) -> Sequence[tuple[int, int, int]]:
with get_session_context_manager() as db_session:
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
def get_access_info_for_document(
db_session: Session,
document_id: str,

View File

@@ -1,4 +1,5 @@
import contextlib
import json
import os
import re
import ssl
@@ -15,6 +16,7 @@ from typing import ContextManager
import asyncpg # type: ignore
import boto3
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy import text
@@ -42,13 +44,13 @@ from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.configs.constants import SSL_CERT_FILE
from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -218,7 +220,6 @@ class SqlEngine:
final_engine_kwargs.update(engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
# echo=True here for inspecting all emitted db queries
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:
@@ -258,13 +259,13 @@ class SqlEngine:
cls._engine = None
def get_all_tenant_ids() -> list[str]:
def get_all_tenant_ids() -> list[str] | list[None]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
if not MULTI_TENANT:
return [POSTGRES_DEFAULT_SCHEMA]
return [None]
with get_session_with_shared_schema() as session:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
result = session.execute(
text(
f"""
@@ -352,6 +353,38 @@ def get_sqlalchemy_async_engine() -> AsyncEngine:
return _ASYNC_ENGINE
async def get_current_tenant_id(request: Request) -> str:
if not MULTI_TENANT:
tenant_id = POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
try:
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if not token_data:
current_value = CURRENT_TENANT_ID_CONTEXTVAR.get()
logger.debug(
f"Token data not found or expired in Redis, defaulting to {current_value}"
)
return current_value
tenant_id = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
return tenant_id
except json.JSONDecodeError:
logger.error("Error decoding token data from Redis")
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in get_current_tenant_id: {str(e)}")
raise HTTPException(status_code=500, detail="Internal server error")
# Listen for events on the synchronous Session class
@event.listens_for(Session, "after_begin")
def _set_search_path(
@@ -377,7 +410,7 @@ async def get_async_session_with_tenant(
tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
if tenant_id is None:
tenant_id = get_current_tenant_id()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not is_valid_schema_name(tenant_id):
logger.error(f"Invalid tenant ID: {tenant_id}")
@@ -400,80 +433,82 @@ async def get_async_session_with_tenant(
@contextmanager
def get_session_with_current_tenant() -> Generator[Session, None, None]:
tenant_id = get_current_tenant_id()
with get_session_with_tenant(tenant_id=tenant_id) as session:
def get_session_with_default_tenant() -> Generator[Session, None, None]:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
# Used in multi tenant mode when need to refer to the shared `public` schema
@contextmanager
def get_session_with_shared_schema() -> Generator[Session, None, None]:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
yield session
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
def get_session_with_tenant(
tenant_id: str | None = None,
) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
This function:
1. Sets the database schema to the specified tenant's schema.
2. Preserves the tenant ID across the session.
3. Reverts to the previous tenant ID after the session is closed.
4. Uses the default schema if no tenant ID is provided.
"""
engine = get_sqlalchemy_engine()
previous_tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() or POSTGRES_DEFAULT_SCHEMA
if tenant_id is None:
tenant_id = POSTGRES_DEFAULT_SCHEMA
engine = get_sqlalchemy_engine()
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
event.listen(engine, "checkout", set_search_path_on_checkout)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID")
with engine.connect() as connection:
dbapi_connection = connection.connection
cursor = dbapi_connection.cursor()
try:
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
finally:
cursor.close()
with Session(bind=connection, expire_on_commit=False) as session:
try:
with engine.connect() as connection:
dbapi_connection = connection.connection
cursor = dbapi_connection.cursor()
try:
yield session
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
)
finally:
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
cursor.close()
with Session(bind=connection, expire_on_commit=False) as session:
try:
yield session
finally:
if MULTI_TENANT:
cursor = dbapi_connection.cursor()
try:
cursor.execute('SET search_path TO "$user", public')
finally:
cursor.close()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.set(previous_tenant_id)
def set_search_path_on_checkout(
dbapi_conn: Any, connection_record: Any, connection_proxy: Any
) -> None:
tenant_id = get_current_tenant_id()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id and is_valid_schema_name(tenant_id):
with dbapi_conn.cursor() as cursor:
cursor.execute(f'SET search_path TO "{tenant_id}"')
def get_session_generator_with_tenant() -> Generator[Session, None, None]:
tenant_id = get_current_tenant_id()
with get_session_with_tenant(tenant_id=tenant_id) as session:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
with get_session_with_tenant(tenant_id) as session:
yield session
def get_session() -> Generator[Session, None, None]:
tenant_id = get_current_tenant_id()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
raise BasicAuthenticationError(detail="User must authenticate")
@@ -488,7 +523,7 @@ def get_session() -> Generator[Session, None, None]:
async def get_async_session() -> AsyncGenerator[AsyncSession, None]:
tenant_id = get_current_tenant_id()
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
engine = get_sqlalchemy_async_engine()
async with AsyncSession(engine, expire_on_commit=False) as async_session:
if MULTI_TENANT:

View File

@@ -73,7 +73,6 @@ class ConnectorCredentialPairStatus(str, PyEnum):
ACTIVE = "ACTIVE"
PAUSED = "PAUSED"
DELETING = "DELETING"
INVALID = "INVALID"
def is_active(self) -> bool:
return self == ConnectorCredentialPairStatus.ACTIVE

View File

@@ -2,7 +2,6 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import TypeVarTuple
from sqlalchemy import and_
from sqlalchemy import delete
@@ -10,13 +9,9 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
@@ -373,33 +368,19 @@ def get_latest_index_attempts_by_status(
return db_session.execute(stmt).scalars().all()
T = TypeVarTuple("T")
def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]:
return stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
def get_latest_index_attempts(
secondary_index: bool,
db_session: Session,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
ids_stmt = select(
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_id"),
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
ids_stmt = ids_stmt.where(SearchSettings.status == status)
if only_finished:
ids_stmt = _add_only_finished_clause(ids_stmt)
if secondary_index:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
else:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()
@@ -414,53 +395,7 @@ def get_latest_index_attempts(
.where(IndexAttempt.id == ids_subquery.c.max_id)
)
if only_finished:
stmt = _add_only_finished_clause(stmt)
if eager_load_cc_pair:
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
return db_session.execute(stmt).scalars().unique().all()
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_latest_index_attempts_parallel(
secondary_index: bool,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,
eager_load_cc_pair,
only_finished,
)
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
)
if only_finished:
stmt = _add_only_finished_clause(stmt)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
stmt = stmt.join(SearchSettings).where(SearchSettings.status == status)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
return db_session.execute(stmt).scalars().all()
def count_index_attempts_for_connector(
@@ -518,12 +453,37 @@ def get_paginated_index_attempts_for_cc_pair_id(
# Apply pagination
stmt = stmt.offset(page * page_size).limit(page_size)
stmt = stmt.options(
contains_eager(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
return list(db_session.execute(stmt).scalars().unique().all())
return list(db_session.execute(stmt).scalars().all())
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
)
if only_finished:
stmt = stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if secondary_index:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.FUTURE
)
else:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
def get_index_attempts_for_cc_pair(

View File

@@ -148,12 +148,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
putting here for simpicity
"""
temperature_override_enabled: Mapped[bool | None] = mapped_column(
Boolean, default=None
)
auto_scroll: Mapped[bool | None] = mapped_column(Boolean, default=None)
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
@@ -205,13 +204,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
@property
def password_configured(self) -> bool:
"""
Returns True if the user has at least one OAuth (or OIDC) account.
"""
return not bool(self.oauth_accounts)
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
pass
@@ -350,9 +342,7 @@ class Document__Tag(Base):
document_id: Mapped[str] = mapped_column(
ForeignKey("document.id"), primary_key=True
)
tag_id: Mapped[int] = mapped_column(
ForeignKey("tag.id"), primary_key=True, index=True
)
tag_id: Mapped[int] = mapped_column(ForeignKey("tag.id"), primary_key=True)
class Persona__Tool(Base):
@@ -570,14 +560,6 @@ class Document(Base):
back_populates="documents",
)
__table_args__ = (
Index(
"ix_document_sync_status",
last_modified,
last_synced,
),
)
class Tag(Base):
__tablename__ = "tag"
@@ -1239,7 +1221,6 @@ class ChatMessage(Base):
DateTime(timezone=True), server_default=func.now()
)
is_agentic: Mapped[bool] = mapped_column(Boolean, default=False)
refined_answer_improvement: Mapped[bool] = mapped_column(Boolean, nullable=True)
chat_session: Mapped[ChatSession] = relationship("ChatSession")
@@ -1761,7 +1742,6 @@ class ChannelConfig(TypedDict):
# If empty list, follow up with no tags
follow_up_tags: NotRequired[list[str]]
show_continue_in_web_ui: NotRequired[bool] # defaults to False
disabled: NotRequired[bool] # defaults to False
class SlackChannelConfig(Base):
@@ -1785,7 +1765,6 @@ class SlackChannelConfig(Base):
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
persona: Mapped[Persona | None] = relationship("Persona")
slack_bot: Mapped["SlackBot"] = relationship(
"SlackBot",
back_populates="slack_channel_configs",

View File

@@ -100,14 +100,9 @@ def _add_user_filters(
.correlate(Persona)
)
else:
# Group the public persona conditions
public_condition = (Persona.is_public == True) & ( # noqa: E712
Persona.is_visible == True # noqa: E712
)
where_clause |= public_condition
where_clause |= Persona.is_public == True # noqa: E712
where_clause &= Persona.is_visible == True # noqa: E712
where_clause |= Persona__User.user_id == user.id
where_clause |= Persona.user_id == user.id
return stmt.where(where_clause)

View File

@@ -13,7 +13,7 @@ from onyx.configs.model_configs import OLD_DEFAULT_DOCUMENT_ENCODER_MODEL
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from onyx.context.search.models import SavedSearchSettings
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_default_tenant
from onyx.db.llm import fetch_embedding_provider
from onyx.db.models import CloudEmbeddingProvider
from onyx.db.models import IndexAttempt
@@ -189,7 +189,7 @@ def get_all_search_settings(db_session: Session) -> list[SearchSettings]:
def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
if db_session is None:
with get_session_with_current_tenant() as db_session:
with get_session_with_default_tenant() as db_session:
search_settings = get_current_search_settings(db_session)
else:
search_settings = get_current_search_settings(db_session)

View File

@@ -151,7 +151,6 @@ def update_slack_channel_config(
channel_config: ChannelConfig,
standard_answer_category_ids: list[int],
enable_auto_filters: bool,
disabled: bool,
) -> SlackChannelConfig:
slack_channel_config = db_session.scalar(
select(SlackChannelConfig).where(

View File

@@ -81,7 +81,7 @@ def translate_boost_count_to_multiplier(boost: int) -> float:
# Vespa's Document API.
def get_document_chunk_ids(
enriched_document_info_list: list[EnrichedDocumentIndexingInfo],
tenant_id: str,
tenant_id: str | None,
large_chunks_enabled: bool,
) -> list[UUID]:
doc_chunk_ids = []
@@ -139,7 +139,7 @@ def get_uuid_from_chunk_info(
*,
document_id: str,
chunk_id: int,
tenant_id: str,
tenant_id: str | None,
large_chunk_id: int | None = None,
) -> UUID:
"""NOTE: be VERY carefuly about changing this function. If changed without a migration,
@@ -154,7 +154,7 @@ def get_uuid_from_chunk_info(
"large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id)
)
unique_identifier_string = "_".join([doc_str, chunk_index])
if MULTI_TENANT:
if tenant_id and MULTI_TENANT:
unique_identifier_string += "_" + tenant_id
uuid_value = uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)

View File

@@ -43,7 +43,7 @@ class IndexBatchParams:
doc_id_to_previous_chunk_cnt: dict[str, int | None]
doc_id_to_new_chunk_cnt: dict[str, int]
tenant_id: str
tenant_id: str | None
large_chunks_enabled: bool
@@ -222,7 +222,7 @@ class Deletable(abc.ABC):
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
) -> int:
"""
@@ -249,7 +249,7 @@ class Updatable(abc.ABC):
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:
@@ -270,7 +270,9 @@ class Updatable(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
def update(
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
) -> None:
"""
Updates some set of chunks. The document and fields to update are specified in the update
requests. Each update request in the list applies its changes to a list of document ids.

Some files were not shown because too many files have changed in this diff Show More