Compare commits


1 Commit

Author    SHA1        Message                      Date
pablonyx  827e693fac  fix starter message editing  2025-02-23 10:52:56 -08:00
281 changed files with 3723 additions and 10060 deletions

.github/CODEOWNERS vendored
View File

@@ -1 +0,0 @@
* @onyx-dot-app/onyx-core-team

View File

@@ -53,90 +53,24 @@ jobs:
exclude: '(?i)^(pylint|aio[-_]*).*'
- name: Print report
if: always()
if: ${{ always() }}
run: echo "${{ steps.license_check_report.outputs.report }}"
- name: Install npm dependencies
working-directory: ./web
run: npm ci
- name: Run Trivy vulnerability scanner in repo mode
uses: aquasecurity/trivy-action@0.28.0
with:
scan-type: fs
scanners: license
format: table
# format: sarif
# output: trivy-results.sarif
severity: HIGH,CRITICAL
# be careful enabling the sarif and upload as it may spam the security tab
# with a huge amount of items. Work out the issues before enabling upload.
# - name: Run Trivy vulnerability scanner in repo mode
# if: always()
# uses: aquasecurity/trivy-action@0.29.0
# - name: Upload Trivy scan results to GitHub Security tab
# uses: github/codeql-action/upload-sarif@v3
# with:
# scan-type: fs
# scan-ref: .
# scanners: license
# format: table
# severity: HIGH,CRITICAL
# # format: sarif
# # output: trivy-results.sarif
#
# # - name: Upload Trivy scan results to GitHub Security tab
# # uses: github/codeql-action/upload-sarif@v3
# # with:
# # sarif_file: trivy-results.sarif
scan-trivy:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
steps:
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# Backend
- name: Pull backend docker image
run: docker pull onyxdotapp/onyx-backend:latest
- name: Run Trivy vulnerability scanner on backend
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-backend:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
# Web server
- name: Pull web server docker image
run: docker pull onyxdotapp/onyx-web-server:latest
- name: Run Trivy vulnerability scanner on web server
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-web-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# Model server
- name: Pull model server docker image
run: docker pull onyxdotapp/onyx-model-server:latest
- name: Run Trivy vulnerability scanner
uses: aquasecurity/trivy-action@0.29.0
env:
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
with:
image-ref: onyxdotapp/onyx-model-server:latest
scanners: license
severity: HIGH,CRITICAL
vuln-type: library
exit-code: 0
# sarif_file: trivy-results.sarif
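
For local reproduction, a minimal sketch of one scan step (assuming Docker and the Trivy CLI are installed; the flags mirror the action inputs above, though exact CLI flag names can vary by Trivy version):

import subprocess

# Mirrors the workflow step: pull the image, then license-scan it with Trivy.
image = "onyxdotapp/onyx-backend:latest"
subprocess.run(["docker", "pull", image], check=True)
subprocess.run(
    [
        "trivy", "image",
        "--scanners", "license",
        "--severity", "HIGH,CRITICAL",
        "--vuln-type", "library",
        image,
    ],
    check=True,  # the workflow uses exit-code 0; check=True fails loudly here instead
)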

View File

@@ -17,13 +17,8 @@ env:
AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
AWS_REGION_NAME: ${{ secrets.AWS_REGION_NAME }}
# API keys for testing
COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
# OpenAI
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
AZURE_API_URL: ${{ secrets.AZURE_API_URL }}
jobs:
model-check:
@@ -77,7 +72,7 @@ jobs:
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
docker compose -f docker-compose.model-server-test.yml -p onyx-stack up -d indexing_model_server
docker compose -f docker-compose.dev.yml -p onyx-stack up -d indexing_model_server
id: start_docker
- name: Wait for service to be ready
@@ -128,22 +123,9 @@ jobs:
--data '{"text":"Scheduled Model Tests failed! Check the run at: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}"}' \
$SLACK_WEBHOOK
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.model-server-test.yml -p onyx-stack down -v
docker compose -f docker-compose.dev.yml -p onyx-stack down -v

View File

@@ -1,84 +0,0 @@
"""improved index
Revision ID: 3bd4c84fe72f
Revises: 8f43500ee275
Create Date: 2025-02-26 13:07:56.217791
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None
# NOTE:
# This migration addresses issues with the previous migration (8f43500ee275) which caused
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
# (see: https://github.com/sqlalchemy/alembic/issues/277)
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade() -> None:
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
)
# Also add a stored tsvector column for chat_session.description
op.execute(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
"""
)
# Commit again before creating the second concurrent index
op.execute("COMMIT")
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
)
def downgrade() -> None:
# Drop the indexes first (use CONCURRENTLY for dropping too)
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;")
op.execute("COMMIT")
op.execute("DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;")
# Then drop the columns
op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -1,32 +0,0 @@
"""add index
Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create a basic index on the lowercase message column for direct text matching
# Limit to 1500 characters to stay well under the 2704-byte limit of btree version 4
# op.execute(
# """
# CREATE INDEX idx_chat_message_message_lower
# ON chat_message (LOWER(substring(message, 1, 1500)))
# """
# )
pass
def downgrade() -> None:
# Drop the index
op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

View File

@@ -1,55 +0,0 @@
"""add background_reindex_enabled field
Revision ID: b7c2b63c4a03
Revises: f11b408e39d3
Create Date: 2024-03-26 12:34:56.789012
"""
from alembic import op
import sqlalchemy as sa
from onyx.db.enums import EmbeddingPrecision
# revision identifiers, used by Alembic.
revision = "b7c2b63c4a03"
down_revision = "f11b408e39d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add background_reindex_enabled column with default value of True
op.add_column(
"search_settings",
sa.Column(
"background_reindex_enabled",
sa.Boolean(),
nullable=False,
server_default="true",
),
)
# Add embedding_precision column with default value of FLOAT
op.add_column(
"search_settings",
sa.Column(
"embedding_precision",
sa.Enum(EmbeddingPrecision, native_enum=False),
nullable=False,
server_default=EmbeddingPrecision.FLOAT.name,
),
)
# Add reduced_dimension column with default value of None
op.add_column(
"search_settings",
sa.Column("reduced_dimension", sa.Integer(), nullable=True),
)
def downgrade() -> None:
# Remove the background_reindex_enabled column
op.drop_column("search_settings", "background_reindex_enabled")
op.drop_column("search_settings", "embedding_precision")
op.drop_column("search_settings", "reduced_dimension")

View File

@@ -1,120 +0,0 @@
"""migrate jira connectors to new format
Revision ID: da42808081e3
Revises: f13db29f3101
Create Date: 2025-02-24 11:24:54.396040
"""
from alembic import op
import sqlalchemy as sa
import json
from onyx.configs.constants import DocumentSource
from onyx.connectors.onyx_jira.utils import extract_jira_project
# revision identifiers, used by Alembic.
revision = "da42808081e3"
down_revision = "f13db29f3101"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config
for connector_id, old_config in jira_connectors:
if not old_config:
continue
# Extract project key from URL if it exists
new_config: dict[str, str | None] = {}
if project_url := old_config.get("jira_project_url"):
# Parse the URL to get base and project
try:
jira_base, project_key = extract_jira_project(project_url)
new_config = {"jira_base_url": jira_base, "project_key": project_key}
except ValueError:
# If URL parsing fails, just use the URL as the base
new_config = {
"jira_base_url": project_url.split("/projects/")[0],
"project_key": None,
}
else:
# For connectors without a project URL, we need admin intervention
# Mark these for review
print(
f"WARNING: Jira connector {connector_id} has no project URL configured"
)
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :id
"""
),
{"id": connector_id, "new_config": json.dumps(new_config)},
)
def downgrade() -> None:
# Get all Jira connectors
conn = op.get_bind()
# First get all Jira connectors
jira_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = :source
"""
),
{"source": DocumentSource.JIRA.value.upper()},
).fetchall()
# Update each connector's config back to the old format
for connector_id, new_config in jira_connectors:
if not new_config:
continue
old_config = {}
base_url = new_config.get("jira_base_url")
project_key = new_config.get("project_key")
if base_url and project_key:
old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"}
elif base_url:
old_config = {"jira_project_url": base_url}
else:
continue
# Update the connector config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :old_config
WHERE id = :id
"""
),
{"id": connector_id, "old_config": old_config},
)
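
The upgrade relies on extract_jira_project to split a project URL into a base URL and project key. A hypothetical sketch of such a helper, inferred from the "/projects/" fallback above (the real implementation in onyx.connectors.onyx_jira.utils may differ):

def extract_jira_project(url: str) -> tuple[str, str]:
    # Hypothetical: "https://foo.atlassian.net/projects/ABC" -> ("https://foo.atlassian.net", "ABC")
    parts = url.rstrip("/").split("/projects/")
    if len(parts) != 2 or not parts[1]:
        raise ValueError(f"not a recognizable Jira project URL: {url}")
    return parts[0], parts[1].split("/")[0]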

View File

@@ -1,36 +0,0 @@
"""force lowercase all users
Revision ID: f11b408e39d3
Revises: 3bd4c84fe72f
Create Date: 2025-02-26 17:04:55.683500
"""
# revision identifiers, used by Alembic.
revision = "f11b408e39d3"
down_revision = "3bd4c84fe72f"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing user emails to lowercase
from alembic import op
op.execute(
"""
UPDATE "user"
SET email = LOWER(email)
"""
)
# 2) Add a check constraint to ensure emails are always lowercase
op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)")
def downgrade() -> None:
# Drop the check constraint
from alembic import op
op.drop_constraint("ensure_lowercase_email", "user", type_="check")

View File

@@ -1,42 +0,0 @@
"""lowercase multi-tenant user auth
Revision ID: 34e3630c7f32
Revises: a4f6ee863c47
Create Date: 2025-02-26 15:03:01.211894
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "34e3630c7f32"
down_revision = "a4f6ee863c47"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Convert all existing rows to lowercase
op.execute(
"""
UPDATE user_tenant_mapping
SET email = LOWER(email)
"""
)
# 2) Add a check constraint so that emails cannot be written in uppercase
op.create_check_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
"email = LOWER(email)",
schema="public",
)
def downgrade() -> None:
# Drop the check constraint
op.drop_constraint(
"ensure_lowercase_email",
"user_tenant_mapping",
schema="public",
type_="check",
)
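
As a sanity check, a small sketch of the constraint's effect (the DSN and the tenant_id column are assumptions; only the table name and email column come from the migration): inserting a mixed-case email should now be rejected.

from sqlalchemy import create_engine, text
from sqlalchemy.exc import IntegrityError

engine = create_engine("postgresql://onyx:onyx@localhost:5432/onyx")  # placeholder DSN

with engine.begin() as conn:
    try:
        conn.execute(
            text("INSERT INTO user_tenant_mapping (email, tenant_id) VALUES (:e, :t)"),
            {"e": "Mixed.Case@Example.com", "t": "tenant_1"},  # tenant_id column assumed
        )
    except IntegrityError:
        # ensure_lowercase_email rejects any row where email != LOWER(email)
        print("constraint rejected the mixed-case email, as expected")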

View File

@@ -5,9 +5,11 @@ from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -16,8 +18,10 @@ logger = setup_logger()
@build_celery_task_wrapper(name_chat_ttl_task)
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
def perform_ttl_management_task(
retention_limit_days: int, *, tenant_id: str | None
) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
delete_chat_sessions_older_than(retention_limit_days, db_session)
@@ -31,19 +35,24 @@ def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) ->
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str) -> None:
def check_ttl_management_task(*, tenant_id: str | None) -> None:
"""Runs periodically to check if any ttl tasks should be run and adds them
to the queue"""
token = None
if MULTI_TENANT and tenant_id is not None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
settings = load_settings()
retention_limit_days = settings.maximum_chat_retention_days
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if should_perform_chat_ttl_check(retention_limit_days, db_session):
perform_ttl_management_task.apply_async(
kwargs=dict(
retention_limit_days=retention_limit_days, tenant_id=tenant_id
),
)
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@celery_app.task(
@@ -51,9 +60,9 @@ def check_ttl_management_task(*, tenant_id: str) -> None:
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
)
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
"""This generates usage report under the /admin/generate-usage/report endpoint"""
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_new_usage_report(
db_session=db_session,
user_id=None,
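
The token set/reset handling around CURRENT_TENANT_ID_CONTEXTVAR in check_ttl_management_task is easy to get wrong if an exception fires between set and reset. A minimal sketch (assuming only that the contextvar is a standard contextvars.ContextVar) of the same pattern wrapped in a context manager so the reset always runs:

import contextvars
from contextlib import contextmanager
from typing import Iterator

CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_tenant_id", default=None
)

@contextmanager
def tenant_context(tenant_id: str | None) -> Iterator[None]:
    token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
    try:
        yield
    finally:
        # reset runs even if the body raises, unlike the manual set/reset pair
        CURRENT_TENANT_ID_CONTEXTVAR.reset(token)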

View File

@@ -18,7 +18,7 @@ logger = setup_logger()
def monitor_usergroup_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
"""This function is likely to move in the worker refactor happening next."""
fence_key = key_bytes.decode("utf-8")

View File

@@ -59,14 +59,10 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""

View File

@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import UserGroup__ConnectorCredentialPair
@@ -36,11 +35,10 @@ def _delete_connector_credential_pair_user_groups_relationship__no_commit(
def get_cc_pairs_by_source(
db_session: Session,
source_type: DocumentSource,
access_type: AccessType | None = None,
status: ConnectorCredentialPairStatus | None = None,
only_sync: bool,
) -> list[ConnectorCredentialPair]:
"""
Get all cc_pairs for a given source type with optional filtering by access_type and status
Get all cc_pairs for a given source type (and optionally only sync)
result is sorted by cc_pair id
"""
query = (
@@ -50,11 +48,8 @@ def get_cc_pairs_by_source(
.order_by(ConnectorCredentialPair.id)
)
if access_type is not None:
query = query.filter(ConnectorCredentialPair.access_type == access_type)
if status is not None:
query = query.filter(ConnectorCredentialPair.status == status)
if only_sync:
query = query.filter(ConnectorCredentialPair.access_type == AccessType.SYNC)
cc_pairs = query.all()
return cc_pairs

View File

@@ -424,7 +424,7 @@ def _validate_curator_status__no_commit(
)
# if the user is a curator in any of their groups, set their role to CURATOR
# otherwise, set their role to BASIC only if they were previously a CURATOR
# otherwise, set their role to BASIC
if curator_relationships:
user.role = UserRole.CURATOR
elif user.role == UserRole.CURATOR:
@@ -631,16 +631,7 @@ def update_user_group(
removed_users = db_session.scalars(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
# Filter out admin and global curator users before validating curator status
users_to_validate = [
user
for user in removed_users
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
]
if users_to_validate:
_validate_curator_status__no_commit(db_session, users_to_validate)
_validate_curator_status__no_commit(db_session, list(removed_users))
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()

View File

@@ -9,16 +9,12 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -346,8 +342,7 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -359,11 +354,7 @@ def confluence_doc_sync(
confluence_connector = ConfluenceConnector(
**cc_pair.connector.connector_specific_config
)
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), "confluence", cc_pair.credential_id
)
confluence_connector.set_credentials_provider(provider)
confluence_connector.load_credentials(cc_pair.credential.credential_json)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)

View File

@@ -1,11 +1,9 @@
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from onyx.background.error_logging import emit_background_error
from onyx.connectors.confluence.onyx_confluence import (
get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
@@ -63,27 +61,13 @@ def _build_group_member_email_map(
def confluence_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
url = wiki_base.rstrip("/")
probe_kwargs = {
"max_backoff_retries": 6,
"max_backoff_seconds": 10,
}
final_kwargs = {
"max_backoff_retries": 10,
"max_backoff_seconds": 60,
}
confluence_client = OnyxConfluence(is_cloud, url, provider)
confluence_client._probe_connection(**probe_kwargs)
confluence_client._initialize_connection(**final_kwargs)
confluence_client = build_confluence_client(
credentials=cc_pair.credential.credential_json,
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
)
group_member_email_map = _build_group_member_email_map(
confluence_client=confluence_client,

View File

@@ -32,8 +32,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@@ -62,14 +62,12 @@ def _fetch_permissions_for_permission_ids(
user_email=(owner_email or google_drive_connector.primary_admin_email),
)
# We continue on 404 or 403 because the document may not exist or the user may not have access to it
fetched_permissions = execute_paginated_retrieval(
retrieval_function=drive_service.permissions().list,
list_key="permissions",
fileId=doc_id,
fields="permissions(id, emailAddress, type, domain)",
supportsAllDrives=True,
continue_on_404_or_403=True,
)
permissions_for_doc_id = []
@@ -106,13 +104,7 @@ def _get_permissions_from_slim_doc(
user_emails: set[str] = set()
group_emails: set[str] = set()
public = False
skipped_permissions = 0
for permission in permissions_list:
if not permission:
skipped_permissions += 1
continue
permission_type = permission["type"]
if permission_type == "user":
user_emails.add(permission["emailAddress"])
@@ -129,11 +121,6 @@ def _get_permissions_from_slim_doc(
elif permission_type == "anyone":
public = True
if skipped_permissions > 0:
logger.warning(
f"Skipped {skipped_permissions} permissions of {len(permissions_list)} for document {slim_doc.id}"
)
drive_id = permission_info.get("drive_id")
group_ids = group_emails | ({drive_id} if drive_id is not None else set())
@@ -145,8 +132,7 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@@ -119,7 +119,6 @@ def _build_onyx_groups(
def gdrive_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> list[ExternalUserGroup]:
# Initialize connector and build credential/service objects

View File

@@ -123,8 +123,7 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
callback: IndexingHeartbeatInterface | None,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres

View File

@@ -28,7 +28,6 @@ DocSyncFuncType = Callable[
GroupSyncFuncType = Callable[
[
str,
ConnectorCredentialPair,
],
list[ExternalUserGroup],

View File

@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth.api import router as oauth_router
from ee.onyx.server.oauth import router as oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -152,8 +152,4 @@ def get_application() -> FastAPI:
# environment variable. Used to automate deployment for multiple environments.
seed_db()
# for debugging discovered routes
# for route in application.router.routes:
# print(f"Path: {route.path}, Methods: {route.methods}")
return application

View File

@@ -22,7 +22,7 @@ from onyx.onyxbot.slack.blocks import get_restate_blocks
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.utils.logger import OnyxLoggingAdapter
from onyx.utils.logger import setup_logger
@@ -216,7 +216,7 @@ def _handle_standard_answers(
all_blocks = restate_question_blocks + answer_blocks
try:
respond_in_thread_or_channel(
respond_in_thread(
client=client,
channel=message_info.channel_to_respond,
receiver_ids=receiver_ids,
@@ -231,7 +231,6 @@ def _handle_standard_answers(
client=client,
channel=message_info.channel_to_respond,
thread_ts=slack_thread_id,
receiver_ids=receiver_ids,
)
return True

View File

@@ -0,0 +1,629 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from onyx.auth.users import current_user
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/oauth")
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.onyx.app/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
class ConfluenceCloudOAuth:
"""work in progress"""
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
TOKEN_URL = "https://auth.atlassian.com/oauth/token"
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
return session
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.onyx.app/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
tenant_id = get_current_tenant_id()
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str
if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
# email=user.email, redirect_on_success=redirect_on_success
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": SlackOAuth.REDIRECT_URI,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"team_id": team_id,
"authed_user_id": authed_user_id,
"redirect_on_success": session.redirect_on_success,
}
)
# Work in progress
# @router.post("/connector/confluence/callback")
# def handle_confluence_oauth_callback(
# code: str,
# state: str,
# user: User = Depends(current_user),
# db_session: Session = Depends(get_session),
# tenant_id: str | None = Depends(get_current_tenant_id),
# ) -> JSONResponse:
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
# raise HTTPException(
# status_code=500,
# detail="Confluence client ID or client secret is not configured."
# )
# r = get_redis_client(tenant_id=tenant_id)
# # recover the state
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
# # Convert bytes back to a UUID
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
# oauth_uuid_str = str(oauth_uuid)
# r_key = f"da_oauth:{oauth_uuid_str}"
# result = r.get(r_key)
# if not result:
# raise HTTPException(
# status_code=400,
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
# )
# try:
# session = ConfluenceCloudOAuth.parse_session(result)
# # Exchange the authorization code for an access token
# response = requests.post(
# ConfluenceCloudOAuth.TOKEN_URL,
# headers={"Content-Type": "application/x-www-form-urlencoded"},
# data={
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
# "code": code,
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
# },
# )
# response_data = response.json()
# if not response_data.get("ok"):
# raise HTTPException(
# status_code=400,
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
# )
# # Extract token and team information
# access_token: str = response_data.get("access_token")
# team_id: str = response_data.get("team", {}).get("id")
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
# credential_info = CredentialBase(
# credential_json={"slack_bot_token": access_token},
# admin_public=True,
# source=DocumentSource.CONFLUENCE,
# name="Confluence OAuth",
# )
# logger.info(f"Slack access token: {access_token}")
# credential = create_credential(credential_info, user, db_session)
# logger.info(f"new_credential_id={credential.id}")
# except Exception as e:
# return JSONResponse(
# status_code=500,
# content={
# "success": False,
# "message": f"An error occurred during Slack OAuth: {str(e)}",
# },
# )
# finally:
# r.delete(r_key)
# # return the result
# return JSONResponse(
# content={
# "success": True,
# "message": "Slack OAuth completed successfully.",
# "team_id": team_id,
# "authed_user_id": authed_user_id,
# "redirect_on_success": session.redirect_on_success,
# }
# )
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client()
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
session: GoogleDriveOAuth.OAuthSession
try:
session = GoogleDriveOAuth.parse_session(session_json)
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"redirect_on_success": session.redirect_on_success,
}
)
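
The state parameter used throughout this file round-trips a UUID through unpadded URL-safe base64: encoded without padding when the flow starts, then re-padded and decoded in each callback. A self-contained sketch of that round trip:

import base64
import uuid

oauth_uuid = uuid.uuid4()

# Encode: 16 raw UUID bytes -> URL-safe base64, padding stripped for a cleaner URL param.
state = base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")

# Decode (as the callbacks do): restore padding to a multiple of 4, then decode.
padded_state = state + "=" * (-len(state) % 4)
assert uuid.UUID(bytes=base64.urlsafe_b64decode(padded_state)) == oauth_uuid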

View File

@@ -1,91 +0,0 @@
import base64
import uuid
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from ee.onyx.server.oauth.api_router import router
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
from ee.onyx.server.oauth.slack import SlackOAuth
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import DocumentSource
from onyx.db.engine import get_current_tenant_id
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_admin_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""
# create random oauth state param for security and to retrieve user data later
oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
# urlsafe b64 encode the uuid for the oauth url
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)
session: str | None = None
if connector == DocumentSource.SLACK:
if not DEV_MODE:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.CONFLUENCE:
if not DEV_MODE:
oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state)
session = ConfluenceCloudOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
elif connector == DocumentSource.GOOGLE_DRIVE:
if not DEV_MODE:
oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state)
else:
oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
session = GoogleDriveOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
else:
oauth_url = None
if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)
if not session:
raise HTTPException(
status_code=500,
detail=f"The document source type {connector} failed to generate an OAuth session.",
)
r = get_redis_client(tenant_id=tenant_id)
# store important session state to retrieve when the user is redirected back
# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)
return JSONResponse(content={"url": oauth_url})

View File

@@ -1,3 +0,0 @@
from fastapi import APIRouter
router: APIRouter = APIRouter(prefix="/oauth")

View File

@@ -1,361 +0,0 @@
import base64
import uuid
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
from onyx.db.credentials import create_credential
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import update_credential_json
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
from onyx.utils.logger import setup_logger
logger = setup_logger()
class ConfluenceCloudOAuth:
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
class AccessibleResources(BaseModel):
id: str
name: str
url: str
scopes: list[str]
avatarUrl: str
CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL
ACCESSIBLE_RESOURCE_URL = (
"https://api.atlassian.com/oauth/token/accessible-resources"
)
# All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
CONFLUENCE_OAUTH_SCOPE = (
# classic scope
"read:confluence-space.summary%20"
"read:confluence-props%20"
"read:confluence-content.all%20"
"read:confluence-content.summary%20"
"read:confluence-content.permission%20"
"read:confluence-user%20"
"read:confluence-groups%20"
"readonly:content.attachment:confluence%20"
"search:confluence%20"
# granular scope
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
"offline_access"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
# eventually for Confluence Data Center
# oauth_url = (
# f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# f"&redirect_uri={redirectme_uri}"
# )
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code
url = (
"https://auth.atlassian.com/authorize"
f"?audience=api.atlassian.com"
f"&client_id={cls.CLIENT_ID}"
f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
f"&redirect_uri={redirect_uri}"
f"&state={state}"
"&response_type=code"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = ConfluenceCloudOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
return session
@classmethod
def generate_finalize_url(cls, credential_id: int) -> str:
return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}"
@router.post("/connector/confluence/callback")
def confluence_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Handles the backend logic for the frontend page that the user is redirected to
after visiting the oauth authorization url."""
if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Confluence Cloud client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = ConfluenceCloudOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI
else:
redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
ConfluenceCloudOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": ConfluenceCloudOAuth.CLIENT_ID,
"client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
token_response: ConfluenceCloudOAuth.TokenResponse | None = None
try:
token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json(
response.text
)
except Exception:
raise RuntimeError(
"Confluence Cloud OAuth failed during code/token exchange."
)
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=token_response.expires_in)
credential_info = CredentialBase(
credential_json={
"confluence_access_token": token_response.access_token,
"confluence_refresh_token": token_response.refresh_token,
"created_at": now.isoformat(),
"expires_at": expires_at.isoformat(),
"expires_in": token_response.expires_in,
"scope": token_response.scope,
},
admin_public=True,
source=DocumentSource.CONFLUENCE,
name="Confluence Cloud OAuth",
)
credential = create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud OAuth completed successfully.",
"finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
"redirect_on_success": session.redirect_on_success,
}
)
@router.get("/connector/confluence/accessible-resources")
def confluence_oauth_accessible_resources(
credential_id: int,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Atlassian's API is weird and does not supply us with enough info to be in a
usable state after authorizing. All API's require a cloud id. We have to list
the accessible resources/sites and let the user choose which site to use."""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if not credential:
raise HTTPException(400, f"Credential {credential_id} not found.")
credential_dict = credential.credential_json
access_token = credential_dict["confluence_access_token"]
try:
        # List the sites (accessible resources) this access token can reach
response = requests.get(
ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
headers={
"Authorization": f"Bearer {access_token}",
"Accept": "application/json",
},
)
response.raise_for_status()
accessible_resources_data = response.json()
# Validate the list of AccessibleResources
try:
accessible_resources = [
ConfluenceCloudOAuth.AccessibleResources(**resource)
for resource in accessible_resources_data
]
except ValidationError as e:
raise RuntimeError(f"Failed to parse accessible resources: {e}")
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
},
)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud get accessible resources completed successfully.",
"accessible_resources": [
resource.model_dump() for resource in accessible_resources
],
}
)
@router.post("/connector/confluence/finalize")
def confluence_oauth_finalize(
credential_id: int,
cloud_id: str,
cloud_name: str,
cloud_url: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Saves the info for the selected cloud site to the credential.
This is the final step in the confluence oauth flow where after the traditional
OAuth process, the user has to select a site to associate with the credentials.
After this, the credential is usable."""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if not credential:
raise HTTPException(
status_code=400,
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
)
new_credential_json: dict[str, Any] = dict(credential.credential_json)
new_credential_json["cloud_id"] = cloud_id
new_credential_json["cloud_name"] = cloud_name
new_credential_json["wiki_base"] = cloud_url
try:
update_credential_json(credential_id, new_credential_json, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
},
)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Confluence Cloud OAuth finalized successfully.",
"redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence",
}
)
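
The callbacks in this change all recover a session UUID from the `state` parameter the same way. Below is a minimal, self-contained sketch of that round-trip, assuming (as the comments above suggest) that the authorize step strips the Base64 padding when it encodes the UUID; the helper names are illustrative, not from the codebase.

import base64
import uuid

# Hypothetical helpers mirroring the state handling above: encode_state is what
# the authorize step presumably does; decode_state matches the callback logic.
def encode_state(oauth_uuid: uuid.UUID) -> str:
    # urlsafe Base64 of the 16 UUID bytes, with the trailing "=" padding stripped
    return base64.urlsafe_b64encode(oauth_uuid.bytes).decode("utf-8").rstrip("=")

def decode_state(state: str) -> uuid.UUID:
    padded_state = state + "=" * (-len(state) % 4)  # restore the required padding
    return uuid.UUID(bytes=base64.urlsafe_b64decode(padded_state))

original = uuid.uuid4()
assert decode_state(encode_state(original)) == original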


@@ -1,229 +0,0 @@
import base64
import json
import uuid
from typing import Any
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_DICT_TOKEN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import (
GoogleOAuthAuthenticationMethod,
)
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
class GoogleDriveOAuth:
# https://developers.google.com/identity/protocols/oauth2
# https://developers.google.com/identity/protocols/oauth2/web-server
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
TOKEN_URL = "https://oauth2.googleapis.com/token"
# SCOPE is per https://docs.danswer.dev/connectors/google-drive
# TODO: Merge with or use google_utils.GOOGLE_SCOPES
SCOPE = (
"https://www.googleapis.com/auth/drive.readonly%20"
"https://www.googleapis.com/auth/drive.metadata.readonly%20"
"https://www.googleapis.com/auth/admin.directory.user.readonly%20"
"https://www.googleapis.com/auth/admin.directory.group.readonly"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# without prompt=consent, a refresh token is only issued the first time the user approves
url = (
f"https://accounts.google.com/o/oauth2/v2/auth"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
"&response_type=code"
f"&scope={cls.SCOPE}"
"&access_type=offline"
f"&state={state}"
"&prompt=consent"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = GoogleDriveOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Google Drive client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = GoogleDriveOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = GoogleDriveOAuth.REDIRECT_URI
else:
redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
GoogleDriveOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": GoogleDriveOAuth.CLIENT_ID,
"client_secret": GoogleDriveOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
"grant_type": "authorization_code",
},
)
response.raise_for_status()
authorization_response: dict[str, Any] = response.json()
# the connector wants us to store the json in its authorized_user_info format
# returned from OAuthCredentials.get_authorized_user_info().
# So refresh immediately via get_google_oauth_creds with the params filled in
# from fields in authorization_response to get the json we need
authorized_user_info = {}
authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
authorized_user_info["refresh_token"] = authorization_response["refresh_token"]
token_json_str = json.dumps(authorized_user_info)
oauth_creds = get_google_oauth_creds(
token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
)
if not oauth_creds:
raise RuntimeError("get_google_oauth_creds returned None.")
# save off the credentials
oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)
credential_dict: dict[str, str] = {}
credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
credential_dict[
DB_CREDENTIALS_AUTHENTICATION_METHOD
] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value
credential_info = CredentialBase(
credential_json=credential_dict,
admin_public=True,
source=DocumentSource.GOOGLE_DRIVE,
name="OAuth (interactive)",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Google Drive OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Google Drive OAuth completed successfully.",
"finalize_url": None,
"redirect_on_success": session.redirect_on_success,
}
)
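
Because SCOPE above bakes literal "%20" separators into the string, it is easy to miss that it is just a space-separated scope list. Here is a sketch of the same consent URL assembled with the standard library instead; the client id, redirect URI, and state are placeholders.

from urllib.parse import quote, urlencode

# Sketch only: build the consent URL with urlencode() instead of string
# concatenation. quote_via=quote renders spaces as "%20" rather than "+".
params = {
    "client_id": "example-client-id.apps.googleusercontent.com",  # placeholder
    "redirect_uri": "https://example.com/admin/connectors/google-drive/oauth/callback",
    "response_type": "code",
    "scope": " ".join(
        [
            "https://www.googleapis.com/auth/drive.readonly",
            "https://www.googleapis.com/auth/drive.metadata.readonly",
        ]
    ),
    "access_type": "offline",  # required to be issued a refresh_token at all
    "prompt": "consent",  # re-issue the refresh_token on repeat approvals
    "state": "opaque-state-value",  # placeholder
}
url = "https://accounts.google.com/o/oauth2/v2/auth?" + urlencode(params, quote_via=quote)
print(url)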


@@ -1,197 +0,0 @@
import base64
import uuid
from typing import cast
import requests
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
from ee.onyx.server.oauth.api_router import router
from onyx.auth.users import current_admin_user
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import DocumentSource
from onyx.db.credentials import create_credential
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.redis.redis_pool import get_redis_client
from onyx.server.documents.models import CredentialBase
class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging
class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""
email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds
CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET
TOKEN_URL = "https://slack.com/api/oauth.v2.access"
# SCOPE is per https://docs.danswer.dev/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)
REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"
@classmethod
def generate_oauth_url(cls, state: str) -> str:
return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)
@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""
return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)
@classmethod
def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={redirect_uri}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url
@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()
@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)
r = get_redis_client(tenant_id=tenant_id)
# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes
# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)
r_key = f"da_oauth:{oauth_uuid_str}"
session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)
session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)
if not DEV_MODE:
redirect_uri = SlackOAuth.REDIRECT_URI
else:
redirect_uri = SlackOAuth.DEV_REDIRECT_URI
# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": redirect_uri,
},
)
response_data = response.json()
if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)
# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")
credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)
create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)
# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"finalize_url": None,
"redirect_on_success": session.redirect_on_success,
"team_id": team_id,
"authed_user_id": authed_user_id,
}
)
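
One detail worth calling out: Slack's oauth.v2.access reports failure with HTTP 200 and an `ok: false` body, which is why the callback checks `ok` rather than relying on raise_for_status(). A minimal sketch of that response shape (token and ids are placeholders):

# Slack's Web API signals errors in the body, not the HTTP status code.
def extract_bot_token(response_data: dict) -> str:
    if not response_data.get("ok"):
        raise RuntimeError(f"Slack OAuth failed: {response_data.get('error')}")
    return response_data["access_token"]

success = {
    "ok": True,
    "access_token": "xoxb-placeholder",
    "team": {"id": "T0000000"},
    "authed_user": {"id": "U0000000"},
}
failure = {"ok": False, "error": "invalid_code"}

assert extract_bot_token(success) == "xoxb-placeholder"
try:
    extract_bot_token(failure)
except RuntimeError as e:
    print(e)  # Slack OAuth failed: invalid_code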


@@ -13,7 +13,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.api_key import is_api_key_email_address
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
from onyx.db.models import TokenRateLimit
@@ -28,21 +28,21 @@ from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_glob
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
def _check_token_rate_limits(user: User | None) -> None:
def _check_token_rate_limits(user: User | None, tenant_id: str) -> None:
if user is None:
# Unauthenticated users are only rate limited by global settings
_user_is_rate_limited_by_global()
_user_is_rate_limited_by_global(tenant_id)
elif is_api_key_email_address(user.email):
# API keys are only rate limited by global settings
_user_is_rate_limited_by_global()
_user_is_rate_limited_by_global(tenant_id)
else:
run_functions_tuples_in_parallel(
[
(_user_is_rate_limited, (user.id,)),
(_user_is_rate_limited_by_group, (user.id,)),
(_user_is_rate_limited_by_global, ()),
(_user_is_rate_limited, (user.id, tenant_id)),
(_user_is_rate_limited_by_group, (user.id, tenant_id)),
(_user_is_rate_limited_by_global, (tenant_id,)),
]
)
@@ -52,8 +52,8 @@ User rate limits
"""
def _user_is_rate_limited(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
def _user_is_rate_limited(user_id: UUID, tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
user_rate_limits = fetch_all_user_token_rate_limits(
db_session=db_session, enabled_only=True, ordered=False
)
@@ -93,8 +93,8 @@ User Group rate limits
"""
def _user_is_rate_limited_by_group(user_id: UUID) -> None:
with get_session_with_current_tenant() as db_session:
def _user_is_rate_limited_by_group(user_id: UUID, tenant_id: str | None) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session)
if group_rate_limits:

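This hunk threads tenant_id explicitly through run_functions_tuples_in_parallel. For readers outside the codebase, here is a hypothetical stand-in for that helper built on the standard library; the real implementation may differ.

from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable

# Hypothetical stand-in for run_functions_tuples_in_parallel: run each
# (func, args) pair on its own thread; .result() re-raises the first failure,
# which is how a rate-limit violation in any check aborts the request.
def run_in_parallel(functions_with_args: list[tuple[Callable[..., Any], tuple]]) -> list[Any]:
    with ThreadPoolExecutor(max_workers=max(len(functions_with_args), 1)) as executor:
        futures = [executor.submit(fn, *args) for fn, args in functions_with_args]
        return [future.result() for future in futures]

def check_global(tenant_id: str) -> None:
    print(f"checked global limits for {tenant_id}")

run_in_parallel([(check_global, ("tenant_1",))])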

@@ -2,7 +2,6 @@ import csv
import io
from datetime import datetime
from datetime import timezone
from http import HTTPStatus
from uuid import UUID
from fastapi import APIRouter
@@ -22,10 +21,8 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_display_email
from onyx.chat.chat_utils import create_chat_chain
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.configs.constants import QueryHistoryType
from onyx.configs.constants import SessionType
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_chat_sessions_by_user
@@ -38,8 +35,6 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
router = APIRouter()
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
def fetch_and_process_chat_session_history(
db_session: Session,
@@ -112,17 +107,6 @@ def get_user_chat_sessions(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionsResponse:
# we specifically don't allow this endpoint if "anonymized" since
# this is a direct query on the user id
if ONYX_QUERY_HISTORY_TYPE in [
QueryHistoryType.DISABLED,
QueryHistoryType.ANONYMIZED,
]:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Per user query history has been disabled by the administrator.",
)
try:
chat_sessions = get_chat_sessions_by_user(
user_id=user_id, deleted=False, db_session=db_session, limit=0
@@ -138,7 +122,6 @@ def get_user_chat_sessions(
name=chat.description,
persona_id=chat.persona_id,
time_created=chat.time_created.isoformat(),
time_updated=chat.time_updated.isoformat(),
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
@@ -158,12 +141,6 @@ def get_chat_session_history(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> PaginatedReturn[ChatSessionMinimal]:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
page_of_chat_sessions = get_page_of_chat_sessions(
page_num=page_num,
page_size=page_size,
@@ -180,16 +157,11 @@ def get_chat_session_history(
feedback_filter=feedback_type,
)
minimal_chat_sessions: list[ChatSessionMinimal] = []
for chat_session in page_of_chat_sessions:
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
minimal_chat_sessions.append(minimal_chat_session)
return PaginatedReturn(
items=minimal_chat_sessions,
items=[
ChatSessionMinimal.from_chat_session(chat_session)
for chat_session in page_of_chat_sessions
],
total_items=total_filtered_chat_sessions_count,
)
@@ -200,12 +172,6 @@ def get_chat_session_admin(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> ChatSessionSnapshot:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
try:
chat_session = get_chat_session_by_id(
chat_session_id=chat_session_id,
@@ -227,9 +193,6 @@ def get_chat_session_admin(
f"Could not create snapshot for chat session with id '{chat_session_id}'",
)
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
return snapshot
@@ -240,12 +203,6 @@ def get_query_history_as_csv(
end: datetime | None = None,
db_session: Session = Depends(get_session),
) -> StreamingResponse:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
raise HTTPException(
status_code=HTTPStatus.FORBIDDEN,
detail="Query history has been disabled by the administrator.",
)
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
@@ -256,9 +213,6 @@ def get_query_history_as_csv(
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
for chat_session_snapshot in complete_chat_session_history:
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
question_answer_pairs.extend(
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
)

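The removed lines above implement the ONYX_QUERY_HISTORY_TYPE gating. A compact sketch of that policy follows, assuming QueryHistoryType has the DISABLED and ANONYMIZED members seen in the diff (the member values and the NORMAL member are guesses):

from enum import Enum
from http import HTTPStatus

class QueryHistoryType(str, Enum):  # member values are assumptions
    DISABLED = "disabled"
    ANONYMIZED = "anonymized"
    NORMAL = "normal"

ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"

def apply_history_policy(user_email: str, history_type: QueryHistoryType) -> str:
    if history_type == QueryHistoryType.DISABLED:
        # the endpoints above raise HTTPException(HTTPStatus.FORBIDDEN, ...)
        raise PermissionError(
            f"{HTTPStatus.FORBIDDEN}: Query history has been disabled by the administrator."
        )
    if history_type == QueryHistoryType.ANONYMIZED:
        return ONYX_ANONYMIZED_EMAIL
    return user_email

assert apply_history_policy("a@b.com", QueryHistoryType.ANONYMIZED) == ONYX_ANONYMIZED_EMAIL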

@@ -7,7 +7,6 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.utils.logger import setup_logger
@@ -42,9 +41,7 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
return response.json()
def fetch_billing_information(
tenant_id: str,
) -> BillingInformation | SubscriptionStatusResponse:
def fetch_billing_information(tenant_id: str) -> BillingInformation:
logger.info("Fetching billing information")
token = generate_data_plane_token()
headers = {
@@ -55,19 +52,8 @@ def fetch_billing_information(
params = {"tenant_id": tenant_id}
response = requests.get(url, headers=headers, params=params)
response.raise_for_status()
response_data = response.json()
# Check if the response indicates no subscription
if (
isinstance(response_data, dict)
and "subscribed" in response_data
and not response_data["subscribed"]
):
return SubscriptionStatusResponse(**response_data)
# Otherwise, parse as BillingInformation
return BillingInformation(**response_data)
billing_info = BillingInformation(**response.json())
return billing_info
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:

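The removed branch distinguishes a bare "not subscribed" control-plane response from full billing information. A sketch of that discrimination; only the `subscribed` field comes from the diff, and the BillingInformation fields are placeholders:

from pydantic import BaseModel

class SubscriptionStatusResponse(BaseModel):
    subscribed: bool

class BillingInformation(BaseModel):  # fields are placeholders
    seats: int
    billing_period_end: str

def parse_billing_response(data: dict) -> BillingInformation | SubscriptionStatusResponse:
    # a bare {"subscribed": false} means there is no subscription to describe
    if isinstance(data, dict) and "subscribed" in data and not data["subscribed"]:
        return SubscriptionStatusResponse(**data)
    return BillingInformation(**data)

print(parse_billing_response({"subscribed": False}))
print(parse_billing_response({"seats": 5, "billing_period_end": "2025-03-01"}))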

@@ -200,35 +200,14 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
def configure_default_api_keys(db_session: Session) -> None:
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
default_model_name="gpt-4",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
@@ -240,6 +219,25 @@ def configure_default_api_keys(db_session: Session) -> None:
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-5-sonnet-20241022",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,


@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
def user_owns_a_tenant(email: str) -> bool:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
result = (
db_session.query(UserTenantMapping)
.filter(UserTenantMapping.email == email)
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
try:
mappings_to_delete = (
db_session.query(UserTenantMapping)
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
def remove_all_users_from_tenant(tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
with get_session_with_tenant(tenant_id=None) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()


@@ -78,7 +78,7 @@ class CloudEmbedding:
self._closed = False
async def _embed_openai(
self, texts: list[str], model: str | None, reduced_dimension: int | None
self, texts: list[str], model: str | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
@@ -91,11 +91,7 @@ class CloudEmbedding:
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
response = await client.embeddings.create(input=text_batch, model=model)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
@@ -227,10 +223,9 @@ class CloudEmbedding:
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
return await self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
@@ -331,7 +326,6 @@ async def embed_text(
prefix: str | None,
api_url: str | None,
api_version: str | None,
reduced_dimension: int | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
@@ -375,7 +369,6 @@ async def embed_text(
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
reduced_dimension=reduced_dimension,
)
if any(embedding is None for embedding in embeddings):
@@ -515,7 +508,6 @@ async def process_embed_request(
text_type=embed_request.text_type,
api_url=embed_request.api_url,
api_version=embed_request.api_version,
reduced_dimension=embed_request.reduced_dimension,
prefix=prefix,
gpu_type=gpu_type,
)

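The reduced_dimension plumbing above maps onto the `dimensions` parameter of OpenAI's embeddings API, which the text-embedding-3 models accept for truncating returned vectors. A minimal sketch, assuming OPENAI_API_KEY is set in the environment:

import asyncio

from openai import NOT_GIVEN, AsyncOpenAI

# Sketch of the reduced_dimension handling above: pass "dimensions" to truncate
# the returned vectors, or NOT_GIVEN to keep the model's native size.
async def embed_with_optional_dimension(
    texts: list[str], reduced_dimension: int | None = None
) -> list[list[float]]:
    client = AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    response = await client.embeddings.create(
        input=texts,
        model="text-embedding-3-small",
        dimensions=reduced_dimension or NOT_GIVEN,
    )
    return [item.embedding for item in response.data]

# vectors = asyncio.run(embed_with_optional_dimension(["hello"], reduced_dimension=256))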

@@ -10,7 +10,6 @@ from pydantic import BaseModel
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
from shared_configs.configs import MULTI_TENANT
_API_KEY_HEADER_NAME = "Authorization"
@@ -36,7 +35,8 @@ class ApiKeyDescriptor(BaseModel):
def generate_api_key(tenant_id: str | None = None) -> str:
if not MULTI_TENANT or not tenant_id:
# For backwards compatibility, if no tenant_id, generate old style key
if not tenant_id:
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
encoded_tenant = quote(tenant_id) # URL encode the tenant ID

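A sketch of the two key shapes implied by generate_api_key above. The prefix, token length, and the delimiter embedding the tenant are illustrative assumptions; only the URL-encoding of the tenant id comes from the diff.

import secrets
from urllib.parse import quote, unquote

_API_KEY_PREFIX = "on_"  # placeholder values
_API_KEY_LEN = 192

def generate_api_key_sketch(tenant_id: str | None = None) -> str:
    if not tenant_id:
        # old-style key: prefix + random token
        return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
    # tenant-aware key: URL-encode the tenant so it can be split back out later
    return f"{_API_KEY_PREFIX}{quote(tenant_id)}.{secrets.token_urlsafe(_API_KEY_LEN)}"

key = generate_api_key_sketch("tenant with spaces")
tenant_part = key[len(_API_KEY_PREFIX):].split(".", 1)[0]
assert unquote(tenant_part) == "tenant with spaces"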

@@ -2,8 +2,6 @@ import smtplib
from datetime import datetime
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
from email.utils import make_msgid
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
@@ -15,7 +13,6 @@ from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
from onyx.db.models import User
from shared_configs.configs import MULTI_TENANT
HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
@@ -153,9 +150,8 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")
if mail_from:
msg["From"] = mail_from
part_text = MIMEText(text_body, "plain")
part_html = MIMEText(html_body, "html")
@@ -177,7 +173,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
subject = "Your Onyx Subscription Has Been Canceled"
heading = "Subscription Canceled"
message = (
"<p>We're sorry to see you go.</p>"
"<p>Were sorry to see you go.</p>"
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
"<p>If you change your mind, you can always come back!</p>"
)
@@ -243,13 +239,13 @@ def send_user_email_invite(
def send_forgot_password_email(
user_email: str,
token: str,
tenant_id: str,
mail_from: str = EMAIL_FROM,
tenant_id: str | None = None,
) -> None:
# Builds a forgot password email with or without fancy HTML
subject = "Onyx Forgot Password"
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
if MULTI_TENANT:
if tenant_id:
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
html_content = build_html_email("Reset Your Password", message)

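The Date and Message-ID headers in this hunk come straight from the standard library, and their absence makes mail noticeably more likely to be junked. A minimal, runnable sketch (addresses are placeholders):

from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate, make_msgid

msg = MIMEMultipart("alternative")
msg["Subject"] = "Onyx Forgot Password"
msg["To"] = "user@example.com"  # placeholder
msg["From"] = "noreply@example.com"  # placeholder
msg["Date"] = formatdate(localtime=True)  # RFC 2822 date for the current time
msg["Message-ID"] = make_msgid(domain="onyx.app")  # unique id; domain sets the right-hand side

msg.attach(MIMEText("Click the link to reset your password.", "plain"))
msg.attach(MIMEText("<p>Click the link to reset your password.</p>", "html"))

print(msg["Date"])
print(msg["Message-ID"])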

@@ -214,7 +214,7 @@ def verify_email_is_invited(email: str) -> None:
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if not get_user_by_email(email, db_session):
verify_email_is_invited(email)
@@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"refresh_token": refresh_token,
}
user: User | None = None
user: User
try:
# Attempt to get user by OAuth account
@@ -420,20 +420,15 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
except exceptions.UserNotExists:
try:
# Attempt to get user by email
user = await self.user_db.get_by_email(account_email)
user = await self.get_by_email(account_email)
if not associate_by_email:
raise exceptions.UserAlreadyExists()
# Make sure user is not None before adding OAuth account
if user is not None:
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
else:
# This shouldn't happen since get_by_email would raise UserNotExists
# but adding as a safeguard
raise exceptions.UserNotExists()
user = await self.user_db.add_oauth_account(
user, oauth_account_dict
)
# If user not found by OAuth account or email, create a new user
except exceptions.UserNotExists:
password = self.password_helper.generate()
user_dict = {
@@ -444,36 +439,26 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user = await self.user_db.create(user_dict)
# Add OAuth account only if user creation was successful
if user is not None:
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
raise HTTPException(
status_code=500, detail="Failed to create user account"
)
# Explicitly set the Postgres schema for this session to ensure
# OAuth account creation happens in the correct tenant schema
# Add OAuth account
await self.user_db.add_oauth_account(user, oauth_account_dict)
await self.on_after_register(user, request)
else:
# User exists, update OAuth account if needed
if user is not None: # Add explicit check
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# Ensure user is not None before proceeding
if user is None:
raise HTTPException(
status_code=500, detail="Failed to authenticate or create user"
)
for existing_oauth_account in user.oauth_accounts:
if (
existing_oauth_account.account_id == account_id
and existing_oauth_account.oauth_name == oauth_name
):
user = await self.user_db.update_oauth_account(
user,
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
# but the type checker doesn't know that :(
existing_oauth_account, # type: ignore
oauth_account_dict,
)
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
# re-authenticate that frequently, so by default this is disabled
@@ -568,7 +553,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
async_return_default_schema,
)(email=user.email)
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
async def on_after_request_verify(
self, user: User, token: str, request: Optional[Request] = None

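For readers outside the fastapi-users internals, the control flow this hunk reshuffles is the classic associate-or-create pattern. A hypothetical, self-contained sketch with an in-memory store standing in for the user database adapter:

# Hypothetical sketch of the oauth_callback flow above, machinery stripped out.
class UserAlreadyExists(Exception):
    pass

users_by_email: dict[str, dict] = {}
oauth_links: dict[tuple[str, str], str] = {}  # (oauth_name, account_id) -> email

def oauth_callback_sketch(
    oauth_name: str, account_id: str, account_email: str, associate_by_email: bool = True
) -> dict:
    # 1) Prefer a user already linked to this exact OAuth account.
    email = oauth_links.get((oauth_name, account_id))
    if email is not None:
        return users_by_email[email]
    # 2) Fall back to an email match, linking the OAuth account to it.
    if account_email in users_by_email:
        if not associate_by_email:
            raise UserAlreadyExists()
        oauth_links[(oauth_name, account_id)] = account_email
        return users_by_email[account_email]
    # 3) No match at all: create the user, then link the account.
    users_by_email[account_email] = {"email": account_email}
    oauth_links[(oauth_name, account_id)] = account_email
    return users_by_email[account_email]

user = oauth_callback_sketch("google", "123", "a@example.com")
assert oauth_callback_sketch("google", "123", "a@example.com") is user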

@@ -2,7 +2,6 @@ import logging
import multiprocessing
import time
from typing import Any
from typing import cast
import sentry_sdk
from celery import Task
@@ -132,9 +131,9 @@ def on_task_postrun(
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
if not kwargs:
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
tenant_id = POSTGRES_DEFAULT_SCHEMA
tenant_id = None
else:
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
tenant_id = kwargs.get("tenant_id")
task_logger.debug(
f"Task {task.name} (ID: {task_id}) completed with state: {state} "


@@ -92,8 +92,7 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue and return them
as a set.
"""This is a redis specific way to build a list of tasks in a queue.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.

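celery_get_queued_task_ids reads the queue once so later membership checks are cheap set lookups. Below is a hypothetical sketch of such a helper for the Redis broker, where queued messages sit in a list named after the queue and carry their task id in the message headers; the real implementation may parse messages differently.

import json

import redis

# Hypothetical sketch: with the Redis broker, each queued Celery message is a
# JSON blob in a list keyed by the queue name; the task id sits in "headers".
def get_queued_task_ids_sketch(queue: str, r: redis.Redis) -> set[str]:
    task_ids: set[str] = set()
    for raw in r.lrange(queue, 0, -1):  # one O(n) read instead of n probes
        message = json.loads(raw)
        task_id = message.get("headers", {}).get("id")
        if task_id:
            task_ids.add(task_id)
    return task_ids

# queued = get_queued_task_ids_sketch("connector_deletion", redis.Redis())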

@@ -34,7 +34,7 @@ def _get_deletion_status(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
tenant_id: str | None = None,
) -> TaskQueueState | None:
"""We no longer store TaskQueueState in the DB for a deletion attempt.
This function populates TaskQueueState by just checking redis.
@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
connector_id: int,
credential_id: int,
db_session: Session,
tenant_id: str,
tenant_id: str | None = None,
) -> DeletionAttemptSnapshot | None:
deletion_task = _get_deletion_status(
connector_id, credential_id, db_session, tenant_id


@@ -8,21 +8,16 @@ from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import fetch_connector_by_id
from onyx.db.connector_credential_pair import add_deletion_failure_message
from onyx.db.connector_credential_pair import (
@@ -57,51 +52,6 @@ class TaskDependencyError(RuntimeError):
with connector deletion."""
def revoke_tasks_blocking_deletion(
redis_connector: RedisConnector, db_session: Session, app: Celery
) -> None:
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(search_settings.id)
try:
index_payload = redis_connector_index.payload
if index_payload and index_payload.celery_task_id:
app.control.revoke(index_payload.celery_task_id)
task_logger.info(
f"Revoked indexing task {index_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking indexing task")
try:
permissions_sync_payload = redis_connector.permissions.payload
if permissions_sync_payload and permissions_sync_payload.celery_task_id:
app.control.revoke(permissions_sync_payload.celery_task_id)
task_logger.info(
f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking pruning task")
try:
prune_payload = redis_connector.prune.payload
if prune_payload and prune_payload.celery_task_id:
app.control.revoke(prune_payload.celery_task_id)
task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.")
except Exception:
task_logger.exception("Exception while revoking permissions sync task")
try:
external_group_sync_payload = redis_connector.external_group_sync.payload
if external_group_sync_payload and external_group_sync_payload.celery_task_id:
app.control.revoke(external_group_sync_payload.celery_task_id)
task_logger.info(
f"Revoked external group sync task {external_group_sync_payload.celery_task_id}."
)
except Exception:
task_logger.exception("Exception while revoking external group sync task")
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
ignore_result=True,
@@ -109,36 +59,22 @@ def revoke_tasks_blocking_deletion(
trail=False,
bind=True,
)
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
def check_for_connector_deletion_task(
self: Task, *, tenant_id: str | None
) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
# Prevent this task from overlapping with itself
# these tasks should never overlap
if not lock_beat.acquire(blocking=False):
return None
try:
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES):
# clear fences that don't have associated celery tasks in progress
try:
validate_connector_deletion_fences(
tenant_id, r, r_replica, r_celery, lock_beat
)
except Exception:
task_logger.exception(
"Exception while validating connector deletion fences"
)
r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300)
# collect cc_pair_ids
cc_pair_ids: list[int] = []
with get_session_with_current_tenant() as db_session:
@@ -156,38 +92,9 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
)
except TaskDependencyError as e:
# this means we wanted to start deleting but dependent tasks were running
# on the first error, we set a stop signal and revoke the dependent tasks
# on subsequent errors, we hard reset blocking fences after our specified timeout
# is exceeded
# Leave a stop signal to clear indexing and pruning tasks more quickly
task_logger.info(str(e))
if not redis_connector.stop.fenced:
# one time revoke of celery tasks
task_logger.info("Revoking any tasks blocking deletion.")
revoke_tasks_blocking_deletion(
redis_connector, db_session, self.app
)
redis_connector.stop.set_fence(True)
redis_connector.stop.set_timeout()
else:
# stop signal already set
if redis_connector.stop.timed_out:
# waiting too long, just reset blocking fences
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()
else:
# just wait
pass
redis_connector.stop.set_fence(True)
else:
# clear the stop signal if it exists ... no longer needed
redis_connector.stop.set_fence(False)
@@ -222,7 +129,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
cc_pair_id: int,
db_session: Session,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Note that syncing can still be required even if the number of sync tasks generated is zero.
@@ -262,7 +169,6 @@ def try_generate_document_cc_pair_cleanup_tasks(
return None
# set a basic fence to start
redis_connector.delete.set_active()
fence_payload = RedisConnectorDeletePayload(
num_tasks=None,
submitted=datetime.now(timezone.utc),
@@ -343,7 +249,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
def monitor_connector_deletion_taskset(
tenant_id: str, key_bytes: bytes, r: Redis
tenant_id: str | None, key_bytes: bytes, r: Redis
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -495,171 +401,3 @@ def monitor_connector_deletion_taskset(
)
redis_connector.delete.reset()
def validate_connector_deletion_fences(
tenant_id: str,
r: Redis,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery)
if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.CONNECTOR_DELETION, r_celery
)
# validate all existing connector deletion jobs
lock_beat.reacquire()
keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
for key in keys:
key_bytes = cast(bytes, key)
key_str = key_bytes.decode("utf-8")
if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
continue
validate_connector_deletion_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
r,
)
lock_beat.reacquire()
return
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
    1. This function renews the active signal with a 5 minute TTL under the following conditions:
    1.1. When the task is seen in the redis queue
    1.2. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
    queued_tasks: the task ids currently in the celery connector deletion queue
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_connector_deletion_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.delete.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.delete.payload
except ValidationError:
task_logger.exception(
"validate_connector_deletion_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return
if not payload:
return
# OK, there's actually something for us to validate
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
    # because we get the celery tasks first, the entries in our own deletion taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
    # TODO: if the number of tasks in celery is much lower than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.delete.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_connector_deletion_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
# we're active if there are still tasks to run and those tasks all exist in celery
if tasks_scanned > 0 and tasks_not_in_celery == 0:
redis_connector.delete.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.delete.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_connector_deletion_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.delete.reset()
return

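The docstring above describes the fence/active-signal pattern in prose; here is a compressed, hypothetical sketch of the same idea against a bare Redis connection (key names and TTL are illustrative):

import redis

FENCE_KEY = "connectordeletion_fence_1"  # illustrative key names
ACTIVE_KEY = "connectordeletion_active_1"
ACTIVE_TTL_SECONDS = 300  # bridges queue/execution transition gaps

def set_active(r: redis.Redis) -> None:
    r.set(ACTIVE_KEY, 1, ex=ACTIVE_TTL_SECONDS)  # renewed whenever progress is observed

def validate_fence(r: redis.Redis, taskset: set[str], tasks_in_celery: set[str]) -> None:
    if not r.exists(FENCE_KEY):
        return  # nothing in flight
    if taskset and taskset <= tasks_in_celery:
        set_active(r)  # every expected task is still queued: healthy
        return
    if r.exists(ACTIVE_KEY):
        return  # inside the TTL grace window; may just be a transition
    # no tasks found and the active signal expired: the worker likely crashed
    r.delete(FENCE_KEY)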

@@ -30,7 +30,6 @@ from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -43,10 +42,8 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
@@ -66,7 +63,6 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
@@ -197,19 +193,12 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None
monitor_ccpair_permissions_taskset(
tenant_id, key_bytes, r, db_session
)
task_logger.info(f"check_for_doc_permissions_sync finished: tenant={tenant_id}")
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id} {error_msg}"
)
task_logger.exception(
f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id}"
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
@@ -221,7 +210,7 @@ def try_creating_permissions_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Returns a randomized payload id on success.
Returns None if no syncing is required."""
@@ -293,19 +282,13 @@ def try_creating_permissions_sync_task(
redis_connector.permissions.set_fence(payload)
payload_id = payload.id
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_permissions_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"try_creating_permissions_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
return payload_id
@@ -320,7 +303,7 @@ def try_creating_permissions_sync_task(
def connector_permission_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles document permission syncing for a given connector credential pair
@@ -405,29 +388,6 @@ def connector_permission_sync_generator_task(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise
source_type = cc_pair.connector.source
doc_sync_func = DOC_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -479,10 +439,6 @@ def connector_permission_sync_generator_task(
redis_connector.permissions.generator_complete = tasks_generated
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}"
)
task_logger.exception(
f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}"
)
@@ -509,7 +465,7 @@ def connector_permission_sync_generator_task(
)
def update_external_document_permissions_task(
self: Task,
tenant_id: str,
tenant_id: str | None,
serialized_doc_external_access: dict,
source_string: str,
connector_id: int,
@@ -517,8 +473,6 @@ def update_external_document_permissions_task(
) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
@@ -558,33 +512,18 @@ def update_external_document_permissions_task(
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Exception in update_external_document_permissions_task: connector_id={connector_id} doc_id={doc_id} {error_msg}"
)
except Exception:
task_logger.exception(
f"update_external_document_permissions_task exceptioned: "
f"Exception in update_external_document_permissions_task: "
f"connector_id={connector_id} doc_id={doc_id}"
)
completion_status = OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
finally:
task_logger.info(
f"update_external_document_permissions_task completed: status={completion_status.value} doc={doc_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
task_logger.info(
f"update_external_document_permissions_task finished: connector_id={connector_id} doc_id={doc_id}"
)
return True
def validate_permission_sync_fences(
tenant_id: str,
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
@@ -631,7 +570,7 @@ def validate_permission_sync_fences(
def validate_permission_sync_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
reserved_tasks: set[str],
@@ -841,7 +780,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
def monitor_ccpair_permissions_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)


@@ -37,11 +37,8 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_external_group_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
@@ -58,7 +55,6 @@ from onyx.redis.redis_connector_ext_group_sync import (
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -123,7 +119,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client()
@@ -152,10 +148,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
# These are ordered by cc_pair id so the first one is the one we want
cc_pairs_to_dedupe = get_cc_pairs_by_source(
db_session,
source,
access_type=AccessType.SYNC,
status=ConnectorCredentialPairStatus.ACTIVE,
db_session, source, only_sync=True
)
# We only want to sync one cc_pair per source type
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
@@ -202,17 +195,12 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected check_for_external_group_sync exception: tenant={tenant_id} {error_msg}"
)
except Exception:
task_logger.exception(f"Unexpected exception: tenant={tenant_id}")
finally:
if lock_beat.owned():
lock_beat.release()
task_logger.info(f"check_for_external_group_sync finished: tenant={tenant_id}")
return True
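check_for_external_group_sync, like the other beat tasks in this diff, wraps its body in a non-blocking Redis lock and releases it in finally only if still owned, so a lock whose timeout expired and was taken over by another worker is never released out from under that worker. A rough redis-py sketch (host and lock name assumed):

from redis import Redis

r = Redis(host="localhost", port=6379)  # assumed local instance
lock_beat = r.lock("da_lock:check_external_group_sync_beat", timeout=120)

if lock_beat.acquire(blocking=False):
    try:
        pass  # ... periodic check body ...
    finally:
        # only release if we still own it; after a timeout the lock
        # may already belong to another worker
        if lock_beat.owned():
            lock_beat.release()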
@@ -220,7 +208,7 @@ def try_creating_external_group_sync_task(
app: Celery,
cc_pair_id: int,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
Returns None if no syncing is required."""
@@ -279,19 +267,12 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.set_fence(payload)
payload_id = payload.id
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_external_group_sync_task exception: cc_pair={cc_pair_id} {error_msg}"
)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
)
return None
task_logger.info(
f"try_creating_external_group_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}"
)
return payload_id
@@ -306,7 +287,7 @@ def try_creating_external_group_sync_task(
def connector_external_group_sync_generator_task(
self: Task,
cc_pair_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""
External group sync task for a given connector credential pair
@@ -387,29 +368,6 @@ def connector_external_group_sync_generator_task(
f"No connector credential pair found for id: {cc_pair_id}"
)
try:
created = validate_ccpair_for_user(
cc_pair.connector.id,
cc_pair.credential.id,
db_session,
enforce_creation=False,
)
if not created:
task_logger.warning(
f"Unable to create connector credential pair for id: {cc_pair_id}"
)
except Exception:
task_logger.exception(
f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}"
)
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise
source_type = cc_pair.connector.source
ext_group_sync_func = GROUP_PERMISSIONS_FUNC_MAP.get(source_type)
@@ -421,18 +379,8 @@ def connector_external_group_sync_generator_task(
logger.info(
f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
)
external_user_groups: list[ExternalUserGroup] = []
try:
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
except ConnectorValidationError as e:
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
update_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
status=ConnectorCredentialPairStatus.INVALID,
)
raise e
external_user_groups: list[ExternalUserGroup] = ext_group_sync_func(cc_pair)
logger.info(
f"Syncing {len(external_user_groups)} external user groups for {source_type}"
@@ -458,14 +406,6 @@ def connector_external_group_sync_generator_task(
sync_status=SyncStatus.SUCCESS,
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id} {error_msg}"
)
task_logger.exception(
f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
)
msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}"
task_logger.exception(msg)
emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id)
@@ -493,7 +433,7 @@ def connector_external_group_sync_generator_task(
def validate_external_group_sync_fences(
tenant_id: str,
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_replica: Redis,
@@ -525,7 +465,7 @@ def validate_external_group_sync_fences(
def validate_external_group_sync_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,

View File

@@ -23,9 +23,9 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.indexing.utils import _should_index
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import should_index
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
@@ -48,7 +48,7 @@ from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.db.connector import mark_ccpair_with_indexing_trigger
from onyx.db.connector_credential_pair import fetch_connector_credential_pairs
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
@@ -61,7 +61,7 @@ from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.db.swap_index import check_index_swap
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.redis.redis_connector import RedisConnector
@@ -182,7 +182,7 @@ class SimpleJobResult:
class ConnectorIndexingContext(BaseModel):
tenant_id: str
tenant_id: str | None
cc_pair_id: int
search_settings_id: int
index_attempt_id: int
@@ -210,7 +210,7 @@ class ConnectorIndexingLogBuilder:
def monitor_ccpair_indexing_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
@@ -358,7 +358,7 @@ def monitor_ccpair_indexing_taskset(
soft_time_limit=300,
bind=True,
)
def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
"""a lightweight task used to kick off indexing tasks.
Occcasionally does some validation of existing state to clear up error conditions"""
@@ -406,7 +406,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# check for search settings swap
with get_session_with_current_tenant() as db_session:
old_search_settings = check_and_perform_index_swap(db_session=db_session)
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
@@ -439,15 +439,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
with get_session_with_current_tenant() as db_session:
search_settings_list = get_active_search_settings_list(db_session)
for search_settings_instance in search_settings_list:
# skip non-live search settings that don't have background reindex enabled
# those should just auto-change to live shortly after creation without
# requiring any indexing till that point
if (
not search_settings_instance.status.is_current()
and not search_settings_instance.background_reindex_enabled
):
continue
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -465,18 +456,23 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
cc_pair.id, search_settings_instance.id, db_session
)
if not should_index(
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.status.is_current():
# the indexing trigger is only checked and cleared with the current search settings
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
@@ -602,7 +598,7 @@ def connector_indexing_task(
cc_pair_id: int,
search_settings_id: int,
is_ee: bool,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -894,7 +890,7 @@ def connector_indexing_proxy_task(
index_attempt_id: int,
cc_pair_id: int,
search_settings_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
and forking is inherently unstable.
@@ -903,9 +899,6 @@ def connector_indexing_proxy_task(
TODO(rkuo): refactor this so that there is a single return path where we canonically
log the result of running this function.
NOTE: we try/except all db access in this function because as a watchdog, this function
needs to be extremely stable.
"""
start = time.monotonic()
@@ -931,7 +924,6 @@ def connector_indexing_proxy_task(
task_logger.error("self.request.id is None!")
client = SimpleJobClient()
task_logger.info(f"submitting connector_indexing_task with tenant_id={tenant_id}")
job = client.submit(
connector_indexing_task,
@@ -1024,7 +1016,7 @@ def connector_indexing_proxy_task(
job.release()
break
# if a termination signal is detected, break (exit point will clean up)
# if a termination signal is detected, clean up and break
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
log_builder.build("Indexing watchdog - termination signal detected")
@@ -1033,7 +1025,6 @@ def connector_indexing_proxy_task(
result.status = IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL
break
# if activity timeout is detected, break (exit point will clean up)
if not redis_connector_index.connector_active():
task_logger.warning(
log_builder.build(
@@ -1042,6 +1033,25 @@ def connector_indexing_proxy_task(
)
)
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
result.status = (
IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT
)
@@ -1060,15 +1070,15 @@ def connector_indexing_proxy_task(
if not index_attempt.is_finished():
continue
except Exception:
# if the DB exceptioned, just restart the check.
# polling the index attempt status doesn't need to be strongly consistent
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception looking up index attempt"
)
)
continue
except Exception as e:
result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED
if isinstance(e, ConnectorValidationError):
@@ -1129,6 +1139,8 @@ def connector_indexing_proxy_task(
"Connector termination signal detected",
)
except Exception:
# if the DB exceptions, we'll just get an unfriendly failure message
# in the UI instead of the cancellation message
task_logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as canceled"
@@ -1136,25 +1148,6 @@ def connector_indexing_proxy_task(
)
job.cancel()
elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT:
try:
with get_session_with_current_tenant() as db_session:
mark_attempt_failed(
index_attempt_id,
db_session,
"Indexing watchdog - activity timeout exceeded: "
f"attempt={index_attempt_id} "
f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s",
)
except Exception:
logger.exception(
log_builder.build(
"Indexing watchdog - transient exception marking index attempt as failed"
)
)
job.cancel()
else:
pass
task_logger.info(
log_builder.build(
@@ -1174,7 +1167,7 @@ def connector_indexing_proxy_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
)
def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
"""Clean up old checkpoints that are older than 7 days."""
locked = False
redis_client = get_redis_client(tenant_id=tenant_id)

View File

@@ -187,7 +187,7 @@ class IndexingCallback(IndexingCallbackBase):
def validate_indexing_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
@@ -311,7 +311,7 @@ def validate_indexing_fence(
def validate_indexing_fences(
tenant_id: str,
tenant_id: str | None,
r_replica: Redis,
r_celery: Redis,
lock_beat: RedisLock,
@@ -346,10 +346,11 @@ def validate_indexing_fences(
return
def should_index(
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -414,9 +415,9 @@ def should_index(
):
return False
if search_settings_instance.status.is_current():
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for live search settings
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq
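The refresh-frequency fall-through in _should_index reduces to: index immediately if nothing has ever run, otherwise wait until refresh_freq seconds have passed since the last attempt. A simplified sketch under that reading (field names assumed):

from datetime import datetime, timezone

def due_for_index(last_attempt_time: datetime | None, refresh_freq: int) -> bool:
    # no attempt has ever occurred -> index regardless of refresh_freq
    if last_attempt_time is None:
        return True
    elapsed = (datetime.now(timezone.utc) - last_attempt_time).total_seconds()
    return elapsed >= refresh_freq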
@@ -441,7 +442,7 @@ def try_creating_indexing_task(
reindex: bool,
db_session: Session,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
"""Checks for any conditions that should block the indexing task from being
created, then creates the task.

View File

@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
trail=False,
bind=True,
)
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
if not LLM_MODEL_UPDATE_API_URL:
raise ValueError("LLM model update API URL not configured")

View File

@@ -91,7 +91,7 @@ class Metric(BaseModel):
}
task_logger.info(json.dumps(data))
def emit(self, tenant_id: str) -> None:
def emit(self, tenant_id: str | None) -> None:
# Convert value to appropriate type based on the input value
bool_value = None
float_value = None
@@ -656,7 +656,7 @@ def build_job_id(
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
"""Collect and emit metrics about background processes.
This task runs periodically to gather metrics about:
- Queue lengths for different Celery queues
@@ -864,7 +864,7 @@ def cloud_monitor_celery_queues(
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
return monitor_celery_queues_helper(self)

View File

@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant
bind=True,
base=AbortableTask,
)
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
"""Runs periodically to clean up the kombu_message table"""
# we will select messages older than this amount to clean up

View File

@@ -55,7 +55,6 @@ from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.server.utils import make_short_id
from onyx.utils.logger import format_error_for_logging
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@@ -114,7 +113,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
soft_time_limit=JOB_TIMEOUT,
bind=True,
)
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client()
r_replica = get_redis_replica_client()
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
@@ -195,14 +194,12 @@ def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(f"Unexpected pruning check exception: {error_msg}")
except Exception:
task_logger.exception("Unexpected exception during pruning check")
finally:
if lock_beat.owned():
lock_beat.release()
task_logger.info(f"check_for_pruning finished: tenant={tenant_id}")
return True
@@ -211,7 +208,7 @@ def try_creating_prune_generator_task(
cc_pair: ConnectorCredentialPair,
db_session: Session,
r: Redis,
tenant_id: str,
tenant_id: str | None,
) -> str | None:
"""Checks for any conditions that should block the pruning generator task from being
created, then creates the task.
@@ -304,19 +301,13 @@ def try_creating_prune_generator_task(
redis_connector.prune.set_fence(payload)
payload_id = payload.id
except Exception as e:
error_msg = format_error_for_logging(e)
task_logger.warning(
f"Unexpected try_creating_prune_generator_task exception: cc_pair={cc_pair.id} {error_msg}"
)
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}")
return None
finally:
if lock.owned():
lock.release()
task_logger.info(
f"try_creating_prune_generator_task finished: cc_pair={cc_pair.id} payload_id={payload_id}"
)
return payload_id
@@ -333,7 +324,7 @@ def connector_pruning_generator_task(
cc_pair_id: int,
connector_id: int,
credential_id: int,
tenant_id: str,
tenant_id: str | None,
) -> None:
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
@@ -521,7 +512,7 @@ def connector_pruning_generator_task(
def monitor_ccpair_pruning_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
@@ -567,7 +558,7 @@ def monitor_ccpair_pruning_taskset(
def validate_pruning_fences(
tenant_id: str,
tenant_id: str | None,
r: Redis,
r_replica: Redis,
r_celery: Redis,
@@ -615,7 +606,7 @@ def validate_pruning_fences(
def validate_pruning_fence(
tenant_id: str,
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
queued_tasks: set[str],

View File

@@ -32,7 +32,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
) -> int:
return self.index.delete_single(
@@ -50,7 +50,7 @@ class RetryDocumentIndex:
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:

View File

@@ -1,5 +1,4 @@
import time
from enum import Enum
from http import HTTPStatus
import httpx
@@ -46,24 +45,6 @@ LIGHT_SOFT_TIME_LIMIT = 105
LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
class OnyxCeleryTaskCompletionStatus(str, Enum):
"""The different statuses the watchdog can finish with.
TODO: create broader success/failure/abort categories
"""
UNDEFINED = "undefined"
SUCCEEDED = "succeeded"
SKIPPED = "skipped"
SOFT_TIME_LIMIT = "soft_time_limit"
NON_RETRYABLE_EXCEPTION = "non_retryable_exception"
RETRYABLE_EXCEPTION = "retryable_exception"
@shared_task(
name=OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
@@ -76,7 +57,7 @@ def document_by_cc_pair_cleanup_task(
document_id: str,
connector_id: int,
credential_id: int,
tenant_id: str,
tenant_id: str | None,
) -> bool:
"""A lightweight subtask used to clean up document to cc pair relationships.
Created by connector deletion and connector pruning parent tasks."""
@@ -97,8 +78,6 @@ def document_by_cc_pair_cleanup_task(
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
action = "skip"
@@ -131,9 +110,6 @@ def document_by_cc_pair_cleanup_task(
db_session=db_session,
document_ids=[document_id],
)
db_session.commit()
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
elif count > 1:
action = "update"
@@ -177,11 +153,10 @@ def document_by_cc_pair_cleanup_task(
)
mark_document_as_synced(document_id, db_session)
db_session.commit()
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
else:
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
pass
db_session.commit()
elapsed = time.monotonic() - start
task_logger.info(
@@ -193,79 +168,57 @@ def document_by_cc_pair_cleanup_task(
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
return False
except Exception as ex:
e: Exception | None = None
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(
f"document_by_cc_pair_cleanup_task exceptioned: doc={document_id}"
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
with get_session_with_current_tenant() as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
return False
task_logger.exception(f"Unexpected exception: doc={document_id}")
if self.request.retries < DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES:
# Still retrying. Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"document_by_cc_pair_cleanup_task completed: status={completion_status.value} doc={document_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
self.retry(exc=e, countdown=countdown)
else:
# This is the last attempt! mark the document as dirty in the db so that it
# eventually gets fixed out of band via stale document reconciliation
task_logger.warning(
f"Max celery task retries reached. Marking doc as dirty for reconciliation: "
f"doc={document_id}"
)
with get_session_with_current_tenant() as db_session:
# delete the cc pair relationship now and let reconciliation clean it up
# in vespa
delete_document_by_connector_credential_pair__no_commit(
db_session=db_session,
document_id=document_id,
connector_credential_pair_identifier=ConnectorCredentialPairIdentifier(
connector_id=connector_id,
credential_id=credential_id,
),
)
mark_document_as_modified(document_id, db_session)
return False
task_logger.info(f"document_by_cc_pair_cleanup_task finished: doc={document_id}")
return True
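The countdown above is 2 ** (retries + 4): 16 s after the first failure, 32 s after the second, 64 s after the third. A sketch of the same arithmetic in a bound Celery task (task body hypothetical):

from celery import shared_task

MAX_RETRIES = 3

@shared_task(bind=True, max_retries=MAX_RETRIES)
def cleanup_doc(self, document_id: str) -> bool:
    try:
        pass  # ... work that may raise ...
    except Exception as e:
        if self.request.retries < MAX_RETRIES:
            # exponential backoff: 2^4, 2^5, 2^6 -> 16, 32, 64 seconds
            countdown = 2 ** (self.request.retries + 4)
            raise self.retry(exc=e, countdown=countdown)
        return False  # retries exhausted; reconcile out of band
    return True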
@@ -297,8 +250,7 @@ def cloud_beat_task_generator(
return None
last_lock_time = time.monotonic()
tenant_ids: list[str] = []
num_processed_tenants = 0
tenant_ids: list[str] | list[None] = []
try:
tenant_ids = get_all_tenant_ids()
@@ -326,8 +278,6 @@ def cloud_beat_task_generator(
expires=expires,
ignore_result=True,
)
num_processed_tenants += 1
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -347,7 +297,6 @@ def cloud_beat_task_generator(
task_logger.info(
f"cloud_beat_task_generator finished: "
f"task={task_name} "
f"num_processed_tenants={num_processed_tenants} "
f"num_tenants={len(tenant_ids)} "
f"elapsed={time_elapsed:.2f}"
)

View File

@@ -19,7 +19,6 @@ from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
@@ -76,7 +75,7 @@ logger = setup_logger()
trail=False,
bind=True,
)
def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
"""Runs periodically to check if any document needs syncing.
Generates sets of tasks for Celery if syncing is needed."""
@@ -208,7 +207,7 @@ def try_generate_stale_document_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
# the fence is up, do nothing
@@ -284,7 +283,7 @@ def try_generate_document_set_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -361,7 +360,7 @@ def try_generate_user_group_sync_tasks(
db_session: Session,
r: Redis,
lock_beat: RedisLock,
tenant_id: str,
tenant_id: str | None,
) -> int | None:
lock_beat.reacquire()
@@ -448,7 +447,7 @@ def monitor_connector_taskset(r: Redis) -> None:
def monitor_document_set_taskset(
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
fence_key = key_bytes.decode("utf-8")
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
@@ -523,11 +522,11 @@ def monitor_document_set_taskset(
time_limit=LIGHT_TIME_LIMIT,
max_retries=3,
)
def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
def vespa_metadata_sync_task(
self: Task, document_id: str, *, tenant_id: str | None
) -> bool:
start = time.monotonic()
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
try:
with get_session_with_current_tenant() as db_session:
active_search_settings = get_active_search_settings(db_session)
@@ -541,103 +540,75 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
doc = get_document(document_id, db_session)
if not doc:
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=no_operation "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED
else:
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
return False
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
# document set sync
doc_sets = fetch_document_sets_for_document(document_id, db_session)
update_doc_sets: set[str] = set(doc_sets)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT
except Exception as ex:
e: Exception | None = None
while True:
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
break
task_logger.exception(
f"vespa_metadata_sync_task exceptioned: doc={document_id}"
# User group sync
doc_access = get_access_for_document(
document_id=document_id, db_session=db_session
)
completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION
if (
self.max_retries is not None
and self.request.retries >= self.max_retries
):
completion_status = (
OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION
)
fields = VespaDocumentFields(
document_sets=update_doc_sets,
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
)
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown) # this will raise a celery exception
break # we won't hit this, but it looks weird not to have it
finally:
task_logger.info(
f"vespa_metadata_sync_task completed: status={completion_status.value} doc={document_id}"
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
chunks_affected = retry_index.update_single(
document_id,
tenant_id=tenant_id,
chunk_count=doc.chunk_count,
fields=fields,
)
# update db last. Worst case = we crash right before this and
# the sync might repeat again later
mark_document_as_synced(document_id, db_session)
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action=sync "
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
e: Exception | None = None
if isinstance(ex, RetryError):
task_logger.warning(
f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}"
)
# only set the inner exception if it is of type Exception
e_temp = ex.last_attempt.exception()
if isinstance(e_temp, Exception):
e = e_temp
else:
e = ex
if isinstance(e, httpx.HTTPStatusError):
if e.response.status_code == HTTPStatus.BAD_REQUEST:
task_logger.exception(
f"Non-retryable HTTPStatusError: "
f"doc={document_id} "
f"status={e.response.status_code}"
)
return False
task_logger.exception(
f"Unexpected exception during vespa metadata sync: doc={document_id}"
)
if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED:
return False
# Exponential backoff from 2^4 to 2^6 ... i.e. 16, 32, 64
countdown = 2 ** (self.request.retries + 4)
self.retry(exc=e, countdown=countdown)
return True
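Several tasks in this diff unwrap a tenacity RetryError to inspect the last underlying exception, e.g. to spot a non-retryable HTTP 400. A minimal sketch of that unwrapping:

from tenacity import retry, stop_after_attempt, RetryError

@retry(stop=stop_after_attempt(3))
def flaky() -> None:
    raise ConnectionError("transient")

try:
    flaky()
except RetryError as ex:
    print(f"attempts={ex.last_attempt.attempt_number}")
    inner = ex.last_attempt.exception()  # the real underlying exception
    if isinstance(inner, ConnectionError):
        print(f"last failure: {inner}")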

View File

@@ -1,5 +1,3 @@
from sqlalchemy.exc import IntegrityError
from onyx.db.background_error import create_background_error
from onyx.db.engine import get_session_with_current_tenant
@@ -11,27 +9,5 @@ def emit_background_error(
"""Currently just saves a row in the background_errors table.
In the future, could create notifications based on the severity."""
error_message = ""
# try to write to the db, but handle IntegrityError specifically
try:
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)
except IntegrityError as e:
# Log an error if the cc_pair_id was deleted or any other exception occurs
error_message = (
f"Failed to create background error: {str(e)}. Original message: {message}"
)
except Exception:
pass
if not error_message:
return
# if we get here from an IntegrityError, try to write the error message to the db
# we need a new session because the first session is now invalid
try:
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, error_message, None)
except Exception:
pass
with get_session_with_current_tenant() as db_session:
create_background_error(db_session, message, cc_pair_id)
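The handling stripped out above existed because create_background_error inserts a row that references the cc pair; if that cc pair was deleted concurrently, the insert raises IntegrityError and the session becomes unusable, so a second, fresh session recorded the failure without the foreign key. A hedged sketch of that two-step fallback (session factory and helper as imported in this file):

from sqlalchemy.exc import IntegrityError

def emit_background_error_safe(message: str, cc_pair_id: int | None) -> None:
    try:
        with get_session_with_current_tenant() as db_session:
            create_background_error(db_session, message, cc_pair_id)
    except IntegrityError as e:
        # the cc_pair row may have been deleted; retry without the FK,
        # on a fresh session, since the first one is now invalid
        with get_session_with_current_tenant() as db_session:
            create_background_error(
                db_session, f"Failed: {e}. Original message: {message}", None
            )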

View File

@@ -16,10 +16,7 @@ from typing import Optional
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
from onyx.db.engine import SqlEngine
from onyx.setup import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -57,15 +54,6 @@ def _initializer(
kwargs = {}
logger.info("Initializing spawned worker child process.")
# 1. Get tenant_id from args or fall back to the default
tenant_id = POSTGRES_DEFAULT_SCHEMA
for arg in reversed(args):
if isinstance(arg, str) and arg.startswith(TENANT_ID_PREFIX):
tenant_id = arg
break
# 2. Set the tenant context before running anything
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Reset the engine in the child process
SqlEngine.reset_engine()
@@ -93,8 +81,6 @@ def _initializer(
queue.put(error_msg) # Send the exception to the parent process
sys.exit(255) # use 255 to indicate a generic exception
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
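The removed initializer lines follow the standard ContextVar discipline: set() returns a token, and the matching reset(token) in finally restores the previous value even on error paths. A small self-contained sketch:

from contextvars import ContextVar

CURRENT_TENANT_ID: ContextVar[str] = ContextVar("tenant_id", default="public")

def run_for_tenant(tenant_id: str) -> None:
    token = CURRENT_TENANT_ID.set(tenant_id)
    try:
        print(f"working as {CURRENT_TENANT_ID.get()}")
    finally:
        # restore whatever was set before, even if the body raised
        CURRENT_TENANT_ID.reset(token)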
def _run_in_process(

View File

@@ -21,9 +21,8 @@ from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
@@ -56,7 +55,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
@@ -69,6 +67,7 @@ def _get_connector_runner(
batch_size: int,
start_time: datetime,
end_time: datetime,
tenant_id: str | None,
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
) -> ConnectorRunner:
"""
@@ -87,23 +86,18 @@ def _get_connector_runner(
input_type=task,
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
credential=attempt.connector_credential_pair.credential,
tenant_id=tenant_id,
)
# validate the connector settings
if not INTEGRATION_TESTS_MODE:
runnable_connector.validate_connector_settings()
except UnexpectedValidationError as e:
logger.exception(
"Unable to instantiate connector due to an unexpected temporary issue."
)
raise e
except Exception as e:
logger.exception("Unable to instantiate connector. Pausing until fixed.")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed
logger.exception(f"Unable to instantiate connector due to {e}")
# Sometimes there are cases where the connector will
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed. Sometimes there are cases where the connector will
# intermittently fail to initialize in which case we should pass in
# leave_connector_active=True to allow it to continue.
# For example, if there is nightly maintenance on a Confluence Server instance,
@@ -247,7 +241,7 @@ def _check_failure_threshold(
def _run_indexing(
db_session: Session,
index_attempt_id: int,
tenant_id: str,
tenant_id: str | None,
callback: IndexingHeartbeatInterface | None = None,
) -> None:
"""
@@ -394,6 +388,7 @@ def _run_indexing(
batch_size=INDEX_BATCH_SIZE,
start_time=window_start,
end_time=window_end,
tenant_id=tenant_id,
)
# don't use a checkpoint if we're explicitly indexing from
@@ -686,7 +681,7 @@ def _run_indexing(
def run_indexing_entrypoint(
index_attempt_id: int,
tenant_id: str,
tenant_id: str | None,
connector_credential_pair_id: int,
is_ee: bool = False,
callback: IndexingHeartbeatInterface | None = None,
@@ -706,7 +701,7 @@ def run_indexing_entrypoint(
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
tenant_str = ""
if MULTI_TENANT:
if tenant_id is not None:
tenant_str = f" for tenant {tenant_id}"
connector_name = attempt.connector_credential_pair.connector.name

View File

@@ -747,13 +747,14 @@ def stream_chat_message_objects(
files=latest_query_files,
single_message_history=single_message_history,
),
system_message=default_build_system_message(prompt_config, llm.config),
system_message=default_build_system_message(prompt_config),
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,
raw_user_uploaded_files=latest_query_files or [],
single_message_history=single_message_history,
)
prompt_builder.update_system_prompt(default_build_system_message(prompt_config))
# LLM prompt building, response capturing, etc.
answer = Answer(
@@ -869,6 +870,7 @@ def stream_chat_message_objects(
for img in img_generation_response
if img.image_data
],
tenant_id=tenant_id,
)
info.ai_message_files.extend(
[

View File

@@ -12,7 +12,6 @@ from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_toke
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.llm.models import PreviousMessage
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_message_tokens
@@ -20,7 +19,6 @@ from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.llm.utils import model_supports_image_input
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
@@ -33,16 +31,8 @@ from onyx.tools.tool import Tool
def default_build_system_message(
prompt_config: PromptConfig,
llm_config: LLMConfig,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
# for o-series markdown generation
if (
llm_config.model_provider == OPENAI_PROVIDER_NAME
and llm_config.model_name.startswith("o")
):
system_prompt = CODE_BLOCK_MARKDOWN + system_prompt
tag_handled_prompt = handle_onyx_date_awareness(
system_prompt,
prompt_config,
@@ -120,8 +110,21 @@ class AnswerPromptBuilder:
),
)
self.update_system_prompt(system_message)
self.update_user_prompt(user_message)
self.system_message_and_token_cnt: tuple[SystemMessage, int] | None = (
(
system_message,
check_message_tokens(system_message, self.llm_tokenizer_encode_func),
)
if system_message
else None
)
self.user_message_and_token_cnt = (
user_message,
check_message_tokens(
user_message,
self.llm_tokenizer_encode_func,
),
)
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
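The builder stores each message alongside its token count so later overflow trimming can run without re-encoding anything. A rough sketch of the (message, token_count) pairing with a stand-in whitespace tokenizer (the real encoder comes from get_tokenizer):

from typing import Callable

def check_message_tokens_sketch(text: str, encode: Callable[[str], list]) -> int:
    # stand-in for the real helper: token count for one message
    return len(encode(text))

def encode(text: str) -> list:
    return text.split()  # hypothetical tokenizer: whitespace tokens

system_message = "You are a helpful assistant."
system_message_and_token_cnt = (
    system_message,
    check_message_tokens_sketch(system_message, encode),
)
print(system_message_and_token_cnt)  # ('You are a helpful assistant.', 5)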

View File

@@ -6,7 +6,6 @@ from typing import cast
from onyx.auth.schemas import AuthBackend
from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
#####
@@ -30,9 +29,6 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
) # 1 day
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
)
#####
# Web Configs
@@ -162,7 +158,7 @@ POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"
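Note the os.environ.get(X) or default idiom in these configs: unlike os.environ.get(X, default), it also falls back when the variable is set but empty. A quick illustration:

import os

os.environ["POSTGRES_HOST"] = ""  # set, but empty

host_or = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
host_default = os.environ.get("POSTGRES_HOST", "127.0.0.1")

print(host_or)       # 127.0.0.1 (empty string is falsy, so the fallback wins)
print(host_default)  # ''        (default applies only when the key is unset)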

View File

@@ -213,12 +213,6 @@ class AuthType(str, Enum):
CLOUD = "cloud"
class QueryHistoryType(str, Enum):
DISABLED = "disabled"
ANONYMIZED = "anonymized"
NORMAL = "normal"
# Special characters for password validation
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
@@ -348,9 +342,6 @@ class OnyxRedisSignals:
BLOCK_PRUNING = "signal:block_pruning"
BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences"
BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table"
BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = (
"signal:block_validate_connector_deletion_fences"
)
class OnyxRedisConstants:

View File

@@ -7,18 +7,11 @@ from typing import Optional
import boto3 # type: ignore
from botocore.client import Config # type: ignore
from botocore.exceptions import ClientError
from botocore.exceptions import NoCredentialsError
from botocore.exceptions import PartialCredentialsError
from mypy_boto3_s3 import S3Client # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import BlobType
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -247,73 +240,6 @@ class BlobStorageConnector(LoadConnector, PollConnector):
return None
def validate_connector_settings(self) -> None:
if self.s3_client is None:
raise ConnectorMissingCredentialError(
"Blob storage credentials not loaded."
)
if not self.bucket_name:
raise ConnectorValidationError(
"No bucket name was provided in connector settings."
)
try:
# We only fetch one object/page as a light-weight validation step.
# This ensures we trigger typical S3 permission checks (ListObjectsV2, etc.).
self.s3_client.list_objects_v2(
Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1
)
except NoCredentialsError:
raise ConnectorMissingCredentialError(
"No valid blob storage credentials found or provided to boto3."
)
except PartialCredentialsError:
raise ConnectorMissingCredentialError(
"Partial or incomplete blob storage credentials provided to boto3."
)
except ClientError as e:
error_code = e.response["Error"].get("Code", "")
status_code = e.response["ResponseMetadata"].get("HTTPStatusCode")
# Most common S3 error cases
if error_code in [
"AccessDenied",
"InvalidAccessKeyId",
"SignatureDoesNotMatch",
]:
if status_code == 403 or error_code == "AccessDenied":
raise InsufficientPermissionsError(
f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. "
"Please check your bucket policy and/or IAM policy."
)
if status_code == 401 or error_code == "SignatureDoesNotMatch":
raise CredentialExpiredError(
"Provided blob storage credentials appear invalid or expired."
)
raise CredentialExpiredError(
f"Credential issue encountered ({error_code})."
)
if error_code == "NoSuchBucket" or status_code == 404:
raise ConnectorValidationError(
f"Bucket '{self.bucket_name}' does not exist or cannot be found."
)
raise ConnectorValidationError(
f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}"
)
except Exception as e:
# Catch-all for anything not captured by the above
# Since we are unsure of the error and it may not disable the connector,
# raise an unexpected error (does not disable connector)
raise UnexpectedValidationError(
f"Unexpected error during blob storage settings validation: {e}"
)
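The removed S3 validation classifies boto3 ClientErrors by inspecting the error code and HTTP status from the response dict. The shape of that inspection as a hedged standalone sketch (return values are illustrative labels, not the project's exception classes):

from botocore.exceptions import ClientError

def classify_s3_error(e: ClientError) -> str:
    error_code = e.response["Error"].get("Code", "")
    status_code = e.response["ResponseMetadata"].get("HTTPStatusCode")
    if status_code == 403 or error_code == "AccessDenied":
        return "insufficient_permissions"
    if status_code == 401 or error_code == "SignatureDoesNotMatch":
        return "credentials_expired"
    if status_code == 404 or error_code == "NoSuchBucket":
        return "bucket_not_found"
    return f"unexpected (code={error_code}, status={status_code})"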
if __name__ == "__main__":
credentials_dict = {

View File

@@ -9,10 +9,10 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.bookstack.client import BookStackApiClient
from onyx.connectors.bookstack.client import BookStackClientRequestFailedError
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch

View File

@@ -4,27 +4,18 @@ from datetime import timezone
from typing import Any
from urllib.parse import quote
from requests.exceptions import HTTPError
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.onyx_confluence import attachment_to_content
from onyx.connectors.confluence.onyx_confluence import (
extract_text_from_confluence_html,
)
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import attachment_to_content
from onyx.connectors.confluence.utils import build_confluence_document_id
from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import LoadConnector
@@ -86,9 +77,7 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
)
class ConfluenceConnector(
LoadConnector, PollConnector, SlimConnector, CredentialsConnector
):
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
wiki_base: str,
@@ -107,6 +96,7 @@ class ConfluenceConnector(
) -> None:
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self._confluence_client: OnyxConfluence | None = None
self.is_cloud = is_cloud
# Remove trailing slash from wiki_base if present
@@ -141,19 +131,6 @@ class ConfluenceConnector(
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
self.credentials_provider: CredentialsProviderInterface | None = None
self.probe_kwargs = {
"max_backoff_retries": 6,
"max_backoff_seconds": 10,
}
self.final_kwargs = {
"max_backoff_retries": 10,
"max_backoff_seconds": 60,
}
self._confluence_client: OnyxConfluence | None = None
@property
def confluence_client(self) -> OnyxConfluence:
@@ -161,22 +138,15 @@ class ConfluenceConnector(
raise ConnectorMissingCredentialError("Confluence")
return self._confluence_client
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
self.credentials_provider = credentials_provider
# raises exception if there's a problem
confluence_client = OnyxConfluence(
self.is_cloud, self.wiki_base, credentials_provider
)
confluence_client._probe_connection(**self.probe_kwargs)
confluence_client._initialize_connection(**self.final_kwargs)
self._confluence_client = confluence_client
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
raise NotImplementedError("Use set_credentials_provider with this connector.")
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
# for a list of other hidden constructor args
self._confluence_client = build_confluence_client(
credentials=credentials,
is_cloud=self.is_cloud,
wiki_base=self.wiki_base,
)
return None
def _construct_page_query(
self,
@@ -226,17 +196,12 @@ class ConfluenceConnector(
return comment_string
def _convert_object_to_document(
self,
confluence_object: dict[str, Any],
parent_content_id: str | None = None,
self, confluence_object: dict[str, Any]
) -> Document | None:
"""
Takes in a confluence object, extracts all metadata, and converts it into a document.
If it's a page, it extracts the text and adds the page's comments to the document text.
If it's an attachment, it just downloads the attachment and converts that into a document.
parent_content_id: if the object is an attachment, specifies the content id that
the attachment is attached to
"""
# The url and the id are the same
object_url = build_confluence_document_id(
@@ -255,9 +220,7 @@ class ConfluenceConnector(
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
confluence_client=self.confluence_client,
attachment=confluence_object,
parent_content_id=parent_content_id,
confluence_client=self.confluence_client, attachment=confluence_object
)
if object_text is None:
@@ -333,7 +296,7 @@ class ConfluenceConnector(
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
doc = self._convert_object_to_document(attachment, confluence_page_id)
doc = self._convert_object_to_document(attachment)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
@@ -434,33 +397,3 @@ class ConfluenceConnector(
callback.progress("retrieve_all_slim_documents", 1)
yield doc_metadata_list
def validate_connector_settings(self) -> None:
if self._confluence_client is None:
raise ConnectorMissingCredentialError("Confluence credentials not loaded.")
try:
spaces = self._confluence_client.get_all_spaces(limit=1)
except HTTPError as e:
status_code = e.response.status_code if e.response else None
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Confluence credentials (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Insufficient permissions to access Confluence resources (HTTP 403)."
)
raise UnexpectedValidationError(
f"Unexpected Confluence error (status={status_code}): {e}"
)
except Exception as e:
raise UnexpectedValidationError(
f"Unexpected error while validating Confluence settings: {e}"
)
if not spaces or not spaces.get("results"):
raise ConnectorValidationError(
"No Confluence spaces found. Either your credentials lack permissions, or "
"there truly are no spaces in this Confluence instance."
)

View File

@@ -1,37 +1,16 @@
import io
import json
import math
import time
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import TypeVar
from urllib.parse import quote
import bs4
from atlassian import Confluence # type:ignore
from pydantic import BaseModel
from redis import Redis
from requests import HTTPError
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -40,14 +19,12 @@ logger = setup_logger()
F = TypeVar("F", bound=Callable[..., Any])
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
# https://jira.atlassian.com/browse/CONFCLOUD-76433
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
_REPLACEMENT_EXPANSIONS = "body.view.value"
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
_USER_EMAIL_CACHE: dict[str, str | None] = {}
class ConfluenceRateLimitError(Exception):
pass
@@ -63,352 +40,127 @@ class ConfluenceUser(BaseModel):
type: str
def _handle_http_error(e: HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
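A minimal, self-contained sketch of the delay policy above, factored out so it can be exercised without a real HTTPError. The constants mirror the ones in _handle_http_error; the function name and the assertions are illustrative, not connector code.
def compute_retry_delay(retry_after_header: str | None, attempt: int) -> int:
    """Seconds to wait before the next attempt."""
    if retry_after_header is not None:
        try:
            # honor the server hint, clamped into [MIN_DELAY, MAX_DELAY]
            return max(2, min(int(retry_after_header), 60))
        except ValueError:
            pass  # malformed header; fall through to exponential backoff
    # no usable hint: STARTING_DELAY * BACKOFF**attempt, capped at MAX_DELAY
    return min(5 * (2**attempt), 60)

assert compute_retry_delay("1", 0) == 2     # clamped up to MIN_DELAY
assert compute_retry_delay("999", 0) == 60  # clamped down to MAX_DELAY
assert compute_retry_delay(None, 3) == 40   # 5 * 2**3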
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)
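A hypothetical usage sketch of the decorator above, assuming handle_confluence_rate_limit is in scope; the stub class and its method are invented stand-ins for a real Confluence client.
class _StubClient:
    def get_all_spaces(self, limit: int = 25) -> dict:
        return {"results": [], "limit": limit}

client = _StubClient()
# wrap one bound method; HTTP 429s would now be retried transparently
get_all_spaces_safe = handle_confluence_rate_limit(client.get_all_spaces)
print(get_all_spaces_safe(limit=1))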
_DEFAULT_PAGINATION_LIMIT = 1000
_MINIMUM_PAGINATION_LIMIT = 50
class OnyxConfluence:
class OnyxConfluence(Confluence):
"""
This is a custom Confluence class that:
A. overrides the default Confluence class to add a custom CQL method.
B.
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
This is necessary because the default Confluence class does not properly support cql expansions.
All methods are automatically wrapped with handle_confluence_rate_limit.
"""
CREDENTIAL_PREFIX = "connector:confluence:credential"
CREDENTIAL_TTL = 300 # 5 min
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
self._wrap_methods()
def __init__(
self,
is_cloud: bool,
url: str,
credentials_provider: CredentialsProviderInterface,
) -> None:
self._is_cloud = is_cloud
self._url = url.rstrip("/")
self._credentials_provider = credentials_provider
self.redis_client: Redis | None = None
self.static_credentials: dict[str, Any] | None = None
if self._credentials_provider.is_dynamic():
self.redis_client = get_redis_client(
tenant_id=credentials_provider.get_tenant_id()
)
else:
self.static_credentials = self._credentials_provider.get_credentials()
self._confluence = Confluence(url)
self.credential_key: str = (
self.CREDENTIAL_PREFIX
+ f":credential_{self._credentials_provider.get_provider_key()}"
)
self._kwargs: Any = None
self.shared_base_kwargs = {
"api_version": "cloud" if is_cloud else "latest",
"backoff_and_retry": True,
"cloud": is_cloud,
}
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
"""credential_json - the current json credentials
Returns a tuple
1. The up-to-date credentials
2. True if the credentials were updated
This method is intended to be used within a distributed lock.
Lock, call this, update credentials if the tokens were refreshed, then release
def _wrap_methods(self) -> None:
"""
# static credentials are preloaded, so no locking/redis required
if self.static_credentials:
return self.static_credentials, False
if not self.redis_client:
raise RuntimeError("self.redis_client is None")
# dynamic credentials need locking
# check redis first, then fallback to the DB
credential_raw = self.redis_client.get(self.credential_key)
if credential_raw is not None:
credential_bytes = cast(bytes, credential_raw)
credential_str = credential_bytes.decode("utf-8")
credential_json: dict[str, Any] = json.loads(credential_str)
else:
credential_json = self._credentials_provider.get_credentials()
if "confluence_refresh_token" not in credential_json:
# static credentials ... cache them permanently and return
self.static_credentials = credential_json
return credential_json, False
# check if we should refresh tokens. we're deciding to refresh halfway
# to expiration
now = datetime.now(timezone.utc)
created_at = datetime.fromisoformat(credential_json["created_at"])
expires_in: int = credential_json["expires_in"]
renew_at = created_at + timedelta(seconds=expires_in // 2)
if now <= renew_at:
# cached/current credentials are reasonably up to date
return credential_json, False
# we need to refresh
logger.info("Renewing Confluence Cloud credentials...")
new_credentials = confluence_refresh_tokens(
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID,
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET,
credential_json["cloud_id"],
credential_json["confluence_refresh_token"],
)
# store the new credentials to redis and to the db thru the provider
# redis: we use a 5 min TTL because we are given a 10 minute grace period
# when keys are rotated. it's easier to expire the cached credentials
# reasonably frequently rather than trying to handle strong synchronization
# between the db and redis everywhere the credentials might be updated
new_credential_str = json.dumps(new_credentials)
self.redis_client.set(
self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
)
self._credentials_provider.set_credentials(new_credentials)
return new_credentials, True
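A minimal sketch of the "refresh halfway to expiry" rule used above, separated from Redis and the provider for clarity. The field names match the credential JSON shown in this diff; the function itself is illustrative.
from datetime import datetime, timedelta, timezone

def should_renew(credential_json: dict) -> bool:
    created_at = datetime.fromisoformat(credential_json["created_at"])
    expires_in = int(credential_json["expires_in"])
    renew_at = created_at + timedelta(seconds=expires_in // 2)
    return datetime.now(timezone.utc) > renew_at

creds = {
    "created_at": datetime.now(timezone.utc).isoformat(),
    "expires_in": 3600,  # Confluence Cloud access tokens last about an hour
}
print(should_renew(creds))  # False until ~30 minutes have elapsed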
@staticmethod
def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]:
oauth2_dict: dict[str, Any] = {}
if "confluence_refresh_token" in credentials:
oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
oauth2_dict["token"] = {}
oauth2_dict["token"]["access_token"] = credentials[
"confluence_access_token"
]
return oauth2_dict
def _probe_connection(
self,
**kwargs: Any,
) -> None:
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
with self._credentials_provider:
credentials, _ = self._renew_credentials()
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
logger.info("Probing Confluence with OAuth Access Token.")
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(
credentials
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
wrap it with handle_confluence_rate_limit.
"""
for attr_name in dir(self):
if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
setattr(
self,
attr_name,
handle_confluence_rate_limit(getattr(self, attr_name)),
)
url = (
f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
)
confluence_client_with_minimal_retries = Confluence(
url=url, oauth2=oauth2_dict, **merged_kwargs
)
else:
logger.info("Probing Confluence with Personal Access Token.")
url = self._url
if self._is_cloud:
confluence_client_with_minimal_retries = Confluence(
url=url,
username=credentials["confluence_username"],
password=credentials["confluence_access_token"],
**merged_kwargs,
)
else:
confluence_client_with_minimal_retries = Confluence(
url=url,
token=credentials["confluence_access_token"],
**merged_kwargs,
)
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfortunately, all data is returned in UTC regardless of the user's time zone
# even though CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
if not spaces:
raise RuntimeError(
f"No spaces found at {url}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
logger.info("Confluence probe succeeded.")
def _initialize_connection(
self,
**kwargs: Any,
) -> None:
"""Called externally to init the connection in a thread safe manner."""
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
with self._credentials_provider:
credentials, _ = self._renew_credentials()
self._confluence = self._initialize_connection_helper(
credentials, **merged_kwargs
)
self._kwargs = merged_kwargs
def _initialize_connection_helper(
self,
credentials: dict[str, Any],
**kwargs: Any,
) -> Confluence:
"""Called internally to init the connection. Distributed locking
to prevent multiple threads from modifying the credentials
must be handled around this function."""
confluence = None
# probe connection with direct client, no retries
if "confluence_refresh_token" in credentials:
logger.info("Connecting to Confluence Cloud with OAuth Access Token.")
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials)
url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs)
else:
logger.info("Connecting to Confluence with Personal Access Token.")
if self._is_cloud:
confluence = Confluence(
url=self._url,
username=credentials["confluence_username"],
password=credentials["confluence_access_token"],
**kwargs,
)
else:
confluence = Confluence(
url=self._url,
token=credentials["confluence_access_token"],
**kwargs,
)
return confluence
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def _make_rate_limited_confluence_method(
self, name: str, credential_provider: CredentialsProviderInterface | None
) -> Callable[..., Any]:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
try:
if credential_provider:
with credential_provider:
credentials, renewed = self._renew_credentials()
if renewed:
self._confluence = self._initialize_connection_helper(
credentials, **self._kwargs
)
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
return attr(*args, **kwargs)
else:
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
return attr(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return wrapped_call
# def _wrap_methods(self) -> None:
# """
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
# wrap it with handle_confluence_rate_limit.
# """
# for attr_name in dir(self):
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
# setattr(
# self,
# attr_name,
# handle_confluence_rate_limit(getattr(self, attr_name)),
# )
# def _ensure_token_valid(self) -> None:
# if self._token_is_expired():
# self._refresh_token()
# # Re-init the Confluence client with the originally stored args
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
def __getattr__(self, name: str) -> Any:
"""Dynamically intercept attribute/method access."""
attr = getattr(self._confluence, name, None)
if attr is None:
# The underlying Confluence client doesn't have this attribute
raise AttributeError(
f"'{type(self).__name__}' object has no attribute '{name}'"
)
# If it's not a method, just return it after ensuring token validity
if not callable(attr):
return attr
# skip methods that start with "_"
if name.startswith("_"):
return attr
# wrap the method with our retry handler
rate_limited_method: Callable[
..., Any
] = self._make_rate_limited_confluence_method(name, self._credentials_provider)
def wrapped_method(*args: Any, **kwargs: Any) -> Any:
return rate_limited_method(*args, **kwargs)
return wrapped_method
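The __getattr__ interception above generalizes to a small proxy pattern; a sketch under the assumption that any wrapper of signature Callable -> Callable will do (all names here are illustrative, not connector code):
from collections.abc import Callable
from typing import Any

class RetryingProxy:
    """Forward attribute access to an inner client, wrapping public methods."""

    def __init__(self, inner: Any, wrapper: Callable[[Callable], Callable]) -> None:
        self._inner = inner
        self._wrapper = wrapper

    def __getattr__(self, name: str) -> Any:
        # only invoked for names not found on the proxy itself
        attr = getattr(self._inner, name)  # raises AttributeError if absent
        if not callable(attr) or name.startswith("_"):
            return attr
        return self._wrapper(attr)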
def _paginate_url(
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
self, url_suffix: str, limit: int | None = None
) -> Iterator[dict[str, Any]]:
"""
This will paginate through the top level query.
@@ -483,41 +235,9 @@ class OnyxConfluence:
raise e
# yield the results individually
results = cast(list[dict[str, Any]], next_response.get("results", []))
yield from results
yield from next_response.get("results", [])
old_url_suffix = url_suffix
url_suffix = cast(str, next_response.get("_links", {}).get("next", ""))
# make sure we don't update the start by more than the amount
# of results we were able to retrieve. The Confluence API has a
# weird behavior where if you pass in a limit that is too large for
# the configured server, it will artificially limit the amount of
# results returned BUT will not apply this to the start parameter.
# This will cause us to miss results.
if url_suffix and "start" in url_suffix:
new_start = get_start_param_from_url(url_suffix)
previous_start = get_start_param_from_url(old_url_suffix)
if new_start - previous_start > len(results):
logger.warning(
f"Start was updated by more than the amount of results "
f"retrieved. This is a bug with Confluence. Start: {new_start}, "
f"Previous Start: {previous_start}, Len Results: {len(results)}."
)
# Update the url_suffix to use the adjusted start
adjusted_start = previous_start + len(results)
url_suffix = update_param_in_path(
url_suffix, "start", str(adjusted_start)
)
# some APIs don't properly paginate, so we need to manually update the `start` param
if auto_paginate and len(results) > 0:
previous_start = get_start_param_from_url(old_url_suffix)
updated_start = previous_start + len(results)
url_suffix = update_param_in_path(
old_url_suffix, "start", str(updated_start)
)
url_suffix = next_response.get("_links", {}).get("next")
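The start-parameter correction above reduces to a small rule: if the server's "next" link advanced `start` by more than the number of results actually returned, rewind it so no rows are skipped. A self-contained sketch with invented numbers (the helper name is illustrative, not the connector's):
def adjust_next_start(old_start: int, new_start: int, num_results: int) -> int:
    if new_start - old_start > num_results:
        # Confluence advanced start by its *requested* limit, not by what
        # it actually returned; resume from what we have actually seen.
        return old_start + num_results
    return new_start

assert adjust_next_start(0, 1000, 250) == 250  # server clamped to 250 rows
assert adjust_next_start(0, 250, 250) == 250   # normal pagination untouched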
def paginated_cql_retrieval(
self,
@@ -577,9 +297,7 @@ class OnyxConfluence:
url = "rest/api/search/user"
expand_string = f"&expand={expand}" if expand else ""
url += f"?cql={cql}{expand_string}"
# endpoint doesn't properly paginate, so we need to manually update the `start` param
# thus the auto_paginate flag
for user_result in self._paginate_url(url, limit, auto_paginate=True):
for user_result in self._paginate_url(url, limit):
# Example response:
# {
# 'user': {
@@ -752,212 +470,59 @@ class OnyxConfluence:
return response
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
logger.warning(f"failed to get confluence email for {user_name}")
# For now, we'll just return None and log a warning. This means
# we will keep retrying to get the email every group sync.
email = None
# We may want to just return a string that indicates failure so we don't
# keep retrying
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
user_id (str): The user id (i.e., the account-id or userkey)
confluence_client (Confluence): The Confluence Client
Returns:
str: The User Display Name, or 'Unknown Confluence User' if the user is deactivated or not found
"""
global _USER_ID_TO_DISPLAY_NAME_CACHE
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
if not found_display_name:
try:
result = confluence_client.get_user_details_by_accountid(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
# why are we using session.get here? we probably won't retry these ... is that ok?
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
def _validate_connector_configuration(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> None:
# test connection with direct client, no retries
confluence_client_with_minimal_retries = Confluence(
api_version="cloud" if is_cloud else "latest",
url=wiki_base.rstrip("/"),
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=6,
max_backoff_seconds=10,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
return extracted_text
# uncomment the following for testing
# the following is an attempt to retrieve the user's timezone
# Unfortunately, all data is returned in UTC regardless of the user's time zone
# even though CQL parses incoming times based on the user's time zone
# space_key = spaces["results"][0]["key"]
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formatted Confluence page
"""
body = confluence_object["body"]
object_html = body.get("storage", body.get("view", {})).get("value")
soup = bs4.BeautifulSoup(object_html, "html.parser")
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
if "ri:account-id" in user.attrs
else user.get("ri:userkey")
)
if not user_id:
logger.warning(
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
)
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
if not spaces:
raise RuntimeError(
f"No spaces found at {wiki_base}! "
"Check your credentials and wiki_base and make sure "
"is_cloud is set correctly."
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
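As a small illustration of the user-mention rewrite performed above, a bs4 sketch with invented markup and a lookup dict standing in for _get_user:
import bs4

html = '<p>Reviewed by <ri:user ri:account-id="abc123"></ri:user></p>'
display_names = {"abc123": "Jane Doe"}  # placeholder id and name

soup = bs4.BeautifulSoup(html, "html.parser")
for user in soup.find_all("ri:user"):
    user_id = user.attrs.get("ri:account-id") or user.get("ri:userkey")
    # "@" prefix makes the mention clearer for downstream LLM consumption
    user.replace_with("@" + display_names.get(user_id, "Unknown Confluence User"))
print(soup.get_text())  # -> Reviewed by @Jane Doe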
def build_confluence_client(
credentials: dict[str, Any],
is_cloud: bool,
wiki_base: str,
) -> OnyxConfluence:
_validate_connector_configuration(
credentials=credentials,
is_cloud=is_cloud,
wiki_base=wiki_base,
)
return OnyxConfluence(
api_version="cloud" if is_cloud else "latest",
# Remove trailing slash from wiki_base if present
url=wiki_base.rstrip("/"),
# passing in username causes issues for Confluence data center
username=credentials["confluence_username"] if is_cloud else None,
password=credentials["confluence_access_token"] if is_cloud else None,
token=credentials["confluence_access_token"] if not is_cloud else None,
backoff_and_retry=True,
max_backoff_retries=10,
max_backoff_seconds=60,
cloud=is_cloud,
)

View File

@@ -1,38 +1,182 @@
import math
import time
from collections.abc import Callable
import io
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import TypeVar
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse
import bs4
import requests
from pydantic import BaseModel
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.onyx_confluence import (
OnyxConfluence,
)
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
pass
logger = setup_logger()
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
_USER_EMAIL_CACHE: dict[str, str | None] = {}
class TokenResponse(BaseModel):
access_token: str
expires_in: int
token_type: str
refresh_token: str
scope: str
def get_user_email_from_username__server(
confluence_client: OnyxConfluence, user_name: str
) -> str | None:
global _USER_EMAIL_CACHE
if _USER_EMAIL_CACHE.get(user_name) is None:
try:
response = confluence_client.get_mobile_parameters(user_name)
email = response.get("email")
except Exception:
logger.warning(f"failed to get confluence email for {user_name}")
# For now, we'll just return None and log a warning. This means
# we will keep retrying to get the email every group sync.
email = None
# We may want to just return a string that indicates failure so we don't
# keep retrying
# email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
_USER_EMAIL_CACHE[user_name] = email
return _USER_EMAIL_CACHE[user_name]
_USER_NOT_FOUND = "Unknown Confluence User"
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
"""Get Confluence Display Name based on the account-id or userkey value
Args:
user_id (str): The user id (i.e., the account-id or userkey)
confluence_client (Confluence): The Confluence Client
Returns:
str: The User Display Name, or 'Unknown Confluence User' if the user is deactivated or not found
"""
global _USER_ID_TO_DISPLAY_NAME_CACHE
if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
try:
result = confluence_client.get_user_details_by_userkey(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
if not found_display_name:
try:
result = confluence_client.get_user_details_by_accountid(user_id)
found_display_name = result.get("displayName")
except Exception:
found_display_name = None
_USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],
fetched_titles: set[str],
) -> str:
"""Parse a Confluence html page and replace the 'user Id' by the real
User Display Name
Args:
confluence_object (dict): The confluence object as a dict
confluence_client (Confluence): Confluence client
fetched_titles (set[str]): The titles of the pages that have already been fetched
Returns:
str: loaded and formatted Confluence page
"""
body = confluence_object["body"]
object_html = body.get("storage", body.get("view", {})).get("value")
soup = bs4.BeautifulSoup(object_html, "html.parser")
for user in soup.findAll("ri:user"):
user_id = (
user.attrs["ri:account-id"]
if "ri:account-id" in user.attrs
else user.get("ri:userkey")
)
if not user_id:
logger.warning(
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
)
continue
# Include @ sign for tagging, more clear for LLM
user.replaceWith("@" + _get_user(confluence_client, user_id))
for html_page_reference in soup.findAll("ac:structured-macro"):
# Here, we only want to process page within page macros
if html_page_reference.attrs.get("ac:name") != "include":
continue
page_data = html_page_reference.find("ri:page")
if not page_data:
logger.warning(
f"Skipping retrieval of {html_page_reference} because because page data is missing"
)
continue
page_title = page_data.attrs.get("ri:content-title")
if not page_title:
# only fetch pages that have a title
logger.warning(
f"Skipping retrieval of {html_page_reference} because it has no title"
)
continue
if page_title in fetched_titles:
# prevent recursive fetching of pages
logger.debug(f"Skipping {page_title} because it has already been fetched")
continue
fetched_titles.add(page_title)
# Wrap this in a try-except because there are some pages that might not exist
try:
page_query = f"type=page and title='{quote(page_title)}'"
page_contents: dict[str, Any] | None = None
# Confluence enforces title uniqueness, so we should only get one result here
for page in confluence_client.paginated_cql_retrieval(
cql=page_query,
expand="body.storage.value",
limit=1,
):
page_contents = page
break
except Exception as e:
logger.warning(
f"Error getting page contents for object {confluence_object}: {e}"
)
continue
if not page_contents:
continue
text_from_page = extract_text_from_confluence_html(
confluence_client=confluence_client,
confluence_object=page_contents,
fetched_titles=fetched_titles,
)
html_page_reference.replaceWith(text_from_page)
for html_link_body in soup.findAll("ac:link-body"):
# This extracts the text from inline links in the page so they can be
# represented in the document text as plain text
try:
text_from_link = html_link_body.text
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
except Exception as e:
logger.warning(f"Error processing ac:link-body: {e}")
return format_document_soup(soup)
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
@@ -46,6 +190,49 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
]
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def build_confluence_document_id(
base_url: str, content_url: str, is_cloud: bool
) -> str:
@@ -92,163 +279,3 @@ def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime_object.astimezone(timezone.utc)
return datetime_object
def confluence_refresh_tokens(
client_id: str, client_secret: str, cloud_id: str, refresh_token: str
) -> dict[str, Any]:
# rotate the refresh and access token
# Note that access tokens are only good for an hour in confluence cloud,
# so we're going to have problems if the connector runs for longer
# https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair
response = requests.post(
CONFLUENCE_OAUTH_TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"grant_type": "refresh_token",
"client_id": client_id,
"client_secret": client_secret,
"refresh_token": refresh_token,
},
)
try:
token_response = TokenResponse.model_validate_json(response.text)
except Exception:
raise RuntimeError("Confluence Cloud token refresh failed.")
now = datetime.now(timezone.utc)
expires_at = now + timedelta(seconds=token_response.expires_in)
new_credentials: dict[str, Any] = {}
new_credentials["confluence_access_token"] = token_response.access_token
new_credentials["confluence_refresh_token"] = token_response.refresh_token
new_credentials["created_at"] = now.isoformat()
new_credentials["expires_at"] = expires_at.isoformat()
new_credentials["expires_in"] = token_response.expires_in
new_credentials["scope"] = token_response.scope
new_credentials["cloud_id"] = cloud_id
return new_credentials
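A hypothetical invocation of the refresh helper above; every value below is a placeholder (in the connector they come from the stored credential JSON and the OAuth app configuration, never from literals like these):
new_creds = confluence_refresh_tokens(
    client_id="my-oauth-app-id",            # placeholder
    client_secret="my-oauth-app-secret",    # placeholder
    cloud_id="11111111-2222-3333-4444-555555555555",  # placeholder
    refresh_token="stored-refresh-token",   # placeholder
)
# Persist *both* tokens: Atlassian rotates the refresh token on every use.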
F = TypeVar("F", bound=Callable[..., Any])
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
MAX_RETRIES = 5
TIMEOUT = 600
timeout_at = time.monotonic() + TIMEOUT
for attempt in range(MAX_RETRIES):
if time.monotonic() > timeout_at:
raise TimeoutError(
f"Confluence call attempts took longer than {TIMEOUT} seconds."
)
try:
# we're relying more on the client to rate limit itself
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except requests.HTTPError as e:
delay_until = _handle_http_error(e, attempt)
logger.warning(
f"HTTPError in confluence call. "
f"Retrying in {delay_until} seconds..."
)
while time.monotonic() < delay_until:
# in the future, check a signal here to exit
time.sleep(1)
except AttributeError as e:
# Some error within the Confluence library, unclear why it fails.
# Users reported it to be intermittent, so just retry
if attempt == MAX_RETRIES - 1:
raise e
logger.exception(
"Confluence Client raised an AttributeError. Retrying..."
)
time.sleep(5)
return cast(F, wrapped_call)
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
BACKOFF = 2
# Check if the response or headers are None to avoid potential AttributeError
if e.response is None or e.response.headers is None:
logger.warning("HTTPError with `None` as response or as headers")
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
):
raise e
retry_after = None
retry_after_header = e.response.headers.get("Retry-After")
if retry_after_header is not None:
try:
retry_after = int(retry_after_header)
if retry_after > MAX_DELAY:
logger.warning(
f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
)
retry_after = MAX_DELAY
if retry_after < MIN_DELAY:
retry_after = MIN_DELAY
except ValueError:
pass
if retry_after is not None:
logger.warning(
f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
)
delay = retry_after
else:
logger.warning(
"Rate limiting without retry header. Retrying with exponential backoff..."
)
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
delay_until = math.ceil(time.monotonic() + delay)
return delay_until
def get_single_param_from_url(url: str, param: str) -> str | None:
"""Get a parameter from a url"""
parsed_url = urlparse(url)
return parse_qs(parsed_url.query).get(param, [None])[0]
def get_start_param_from_url(url: str) -> int:
"""Get the start parameter from a url"""
start_str = get_single_param_from_url(url, "start")
if start_str is None:
return 0
return int(start_str)
def update_param_in_path(path: str, param: str, value: str) -> str:
"""Update a parameter in a path. Path should look something like:
/api/rest/users?start=0&limit=10
"""
parsed_url = urlparse(path)
query_params = parse_qs(parsed_url.query)
query_params[param] = [value]
return (
path.split("?")[0]
+ "?"
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
)
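A quick sanity check of the two URL helpers above, assuming they are in scope; the path is an example:
path = "/rest/api/search/user?start=0&limit=50"
assert get_start_param_from_url(path) == 0
assert update_param_in_path(path, "start", "50") == (
    "/rest/api/search/user?start=50&limit=50"
)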

View File

@@ -1,135 +0,0 @@
import uuid
from types import TracebackType
from typing import Any
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.db.engine import get_session_with_tenant
from onyx.db.models import Credential
from onyx.redis.redis_pool import get_redis_client
class OnyxDBCredentialsProvider(
CredentialsProviderInterface["OnyxDBCredentialsProvider"]
):
"""Implementation to allow the connector to callback and update credentials in the db.
Required in cases where credentials can rotate while the connector is running.
"""
LOCK_TTL = 900 # TTL of the lock
def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_id = credential_id
self.redis_client = get_redis_client(tenant_id=tenant_id)
# lock used to prevent overlapping renewal of credentials
self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)
def __enter__(self) -> "OnyxDBCredentialsProvider":
acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
if not acquired:
raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
"""Release the lock when exiting the context."""
if self._lock and self._lock.owned():
self._lock.release()
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return str(self._credential_id)
def get_credentials(self) -> dict[str, Any]:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
credential = db_session.execute(
select(Credential).where(Credential.id == self._credential_id)
).scalar_one()
if credential is None:
raise ValueError(
f"No credential found: credential={self._credential_id}"
)
return credential.credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
try:
credential = db_session.execute(
select(Credential)
.where(Credential.id == self._credential_id)
.with_for_update()
).scalar_one()
if credential is None:
raise ValueError(
f"No credential found: credential={self._credential_id}"
)
credential.credential_json = credential_json
db_session.commit()
except Exception:
db_session.rollback()
raise
def is_dynamic(self) -> bool:
return True
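A sketch of the intended locking protocol: enter the provider (acquiring the per-credential Redis lock), renew inside, write back, then exit. This needs a live Redis and DB; the ids are placeholders.
provider = OnyxDBCredentialsProvider(
    tenant_id="public", connector_name="confluence", credential_id=1
)
with provider:  # blocks until the per-credential Redis lock is acquired
    creds = provider.get_credentials()
    # ... refresh tokens here if needed ...
    provider.set_credentials(creds)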
class OnyxStaticCredentialsProvider(
CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
):
"""Implementation (a very simple one!) to handle static credentials."""
def __init__(
self,
tenant_id: str | None,
connector_name: str,
credential_json: dict[str, Any],
):
self._tenant_id = tenant_id
self._connector_name = connector_name
self._credential_json = credential_json
self._provider_key = str(uuid.uuid4())
def __enter__(self) -> "OnyxStaticCredentialsProvider":
return self
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
pass
def get_tenant_id(self) -> str | None:
return self._tenant_id
def get_provider_key(self) -> str:
return self._provider_key
def get_credentials(self) -> dict[str, Any]:
return self._credential_json
def set_credentials(self, credential_json: dict[str, Any]) -> None:
self._credential_json = credential_json
def is_dynamic(self) -> bool:
return False

View File

@@ -10,10 +10,10 @@ from dropbox.files import FolderMetadata # type:ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialInvalidError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialInvalidError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch

View File

@@ -1,52 +0,0 @@
class ValidationError(Exception):
"""General exception for validation errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class ConnectorValidationError(ValidationError):
"""General exception for connector validation errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class UnexpectedValidationError(ValidationError):
"""Raised when an unexpected error occurs during connector validation.
Unexpected errors don't necessarily mean the credential is invalid,
but rather that there was an error during the validation process
or we encountered a currently unhandled error case.
Currently, unexpected validation errors are defined as transient and should not be
used to disable the connector.
"""
def __init__(self, message: str = "Unexpected error during connector validation"):
super().__init__(message)
class CredentialInvalidError(ConnectorValidationError):
"""Raised when a connector's credential is invalid."""
def __init__(self, message: str = "Credential is invalid"):
super().__init__(message)
class CredentialExpiredError(ConnectorValidationError):
"""Raised when a connector's credential is expired."""
def __init__(self, message: str = "Credential has expired"):
super().__init__(message)
class InsufficientPermissionsError(ConnectorValidationError):
"""Raised when the credential does not have sufficient API permissions."""
def __init__(
self, message: str = "Insufficient permissions for the requested operation"
):
super().__init__(message)
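Illustrative caller-side handling of this hierarchy, with `connector` as a placeholder for any connector instance. Because CredentialExpiredError and InsufficientPermissionsError subclass ConnectorValidationError, the broad except sits below the specific ones:
try:
    connector.validate_connector_settings()
except CredentialExpiredError:
    print("ask the user to re-authenticate")
except InsufficientPermissionsError:
    print("ask an admin to grant the missing scopes")
except ConnectorValidationError as e:
    print(f"configuration problem: {e}")
except UnexpectedValidationError as e:
    # sibling of ConnectorValidationError: transient, should not disable the connector
    print(f"transient validation failure, will retry: {e}")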

View File

@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import DocumentSourceRequiringTenantContext
from onyx.connectors.airtable.airtable_connector import AirtableConnector
from onyx.connectors.asana.connector import AsanaConnector
from onyx.connectors.axero.connector import AxeroConnector
@@ -12,13 +13,11 @@ from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.bookstack.connector import BookstackConnector
from onyx.connectors.clickup.connector import ClickupConnector
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.discord.connector import DiscordConnector
from onyx.connectors.discourse.connector import DiscourseConnector
from onyx.connectors.document360.connector import Document360Connector
from onyx.connectors.dropbox.connector import DropboxConnector
from onyx.connectors.egnyte.connector import EgnyteConnector
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.fireflies.connector import FirefliesConnector
from onyx.connectors.freshdesk.connector import FreshdeskConnector
@@ -33,7 +32,7 @@ from onyx.connectors.guru.connector import GuruConnector
from onyx.connectors.hubspot.connector import HubSpotConnector
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import EventConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -57,9 +56,9 @@ from onyx.connectors.zendesk.connector import ZendeskConnector
from onyx.connectors.zulip.connector import ZulipConnector
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import backend_update_credential_json
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.models import Credential
from shared_configs.contextvars import get_current_tenant_id
from onyx.db.models import User
class ConnectorMissingException(Exception):
@@ -166,21 +165,18 @@ def instantiate_connector(
input_type: InputType,
connector_specific_config: dict[str, Any],
credential: Credential,
tenant_id: str | None = None,
) -> BaseConnector:
connector_class = identify_connector_class(source, input_type)
if source in DocumentSourceRequiringTenantContext:
connector_specific_config["tenant_id"] = tenant_id
connector = connector_class(**connector_specific_config)
new_credentials = connector.load_credentials(credential.credential_json)
if isinstance(connector, CredentialsConnector):
provider = OnyxDBCredentialsProvider(
get_current_tenant_id(), str(source), credential.id
)
connector.set_credentials_provider(provider)
else:
new_credentials = connector.load_credentials(credential.credential_json)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)
return connector
@@ -189,16 +185,19 @@ def validate_ccpair_for_user(
connector_id: int,
credential_id: int,
db_session: Session,
enforce_creation: bool = True,
) -> bool:
user: User | None,
tenant_id: str | None,
) -> None:
if INTEGRATION_TESTS_MODE:
return True
return
# Validate the connector settings
connector = fetch_connector_by_id(connector_id, db_session)
credential = fetch_credential_by_id(
credential = fetch_credential_by_id_for_user(
credential_id,
user,
db_session,
get_editable=False,
)
if not connector:
@@ -208,7 +207,7 @@ def validate_ccpair_for_user(
connector.source == DocumentSource.INGESTION_API
or connector.source == DocumentSource.MOCK_CONNECTOR
):
return True
return
if not credential:
raise ValueError("Credential not found")
@@ -220,14 +219,9 @@ def validate_ccpair_for_user(
input_type=connector.input_type,
connector_specific_config=connector.connector_specific_config,
credential=credential,
tenant_id=tenant_id,
)
except ConnectorValidationError as e:
raise e
except Exception as e:
if enforce_creation:
raise ConnectorValidationError(str(e))
else:
return False
raise ConnectorValidationError(str(e))
runnable_connector.validate_connector_settings()
return True

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.engine import get_session_with_tenant
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -27,6 +27,8 @@ from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -163,10 +165,12 @@ class LocalFileConnector(LoadConnector):
def __init__(
self,
file_locations: list[Path | str],
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [Path(file_location) for file_location in file_locations]
self.batch_size = batch_size
self.tenant_id = tenant_id
self.pdf_pass: str | None = None
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
@@ -175,8 +179,9 @@ class LocalFileConnector(LoadConnector):
def load_from_state(self) -> GenerateDocumentsOutput:
documents: list[Document] = []
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
with get_session_with_current_tenant() as db_session:
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files = _read_files_and_metadata(
@@ -198,6 +203,8 @@ class LocalFileConnector(LoadConnector):
if documents:
yield documents
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
if __name__ == "__main__":
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])

View File

@@ -17,14 +17,14 @@ from github.PullRequest import PullRequest
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
@@ -124,7 +124,7 @@ class GithubConnector(LoadConnector, PollConnector):
def __init__(
self,
repo_owner: str,
repo_name: str | None = None,
repo_name: str,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
@@ -162,81 +162,53 @@ class GithubConnector(LoadConnector, PollConnector):
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
def _get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
# Try to get organization first
try:
org = github_client.get_organization(self.repo_owner)
return list(org.get_repos())
except GithubException:
# If not an org, try as a user
user = github_client.get_user(self.repo_owner)
return list(user.get_repos())
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_all_repos(github_client, attempt_num + 1)
def _fetch_from_github(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repos = (
[self._get_github_repo(self.github_client)]
if self.repo_name
else self._get_all_repos(self.github_client)
)
repo = self._get_github_repo(self.github_client)
for repo in repos:
if self.include_prs:
logger.info(f"Fetching PRs for repo: {repo.name}")
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
if self.include_prs:
pull_requests = repo.get_pulls(
state=self.state_filter, sort="updated", direction="desc"
)
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
break
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
for pr_batch in _batch_github_objects(
pull_requests, self.github_client, self.batch_size
):
doc_batch: list[Document] = []
for pr in pr_batch:
if start is not None and pr.updated_at < start:
yield doc_batch
return
if end is not None and pr.updated_at > end:
continue
doc_batch.append(_convert_pr_to_document(cast(PullRequest, pr)))
yield doc_batch
if self.include_issues:
logger.info(f"Fetching issues for repo: {repo.name}")
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
if self.include_issues:
issues = repo.get_issues(
state=self.state_filter, sort="updated", direction="desc"
)
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
break
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
for issue_batch in _batch_github_objects(
issues, self.github_client, self.batch_size
):
doc_batch = []
for issue in issue_batch:
issue = cast(Issue, issue)
if start is not None and issue.updated_at < start:
yield doc_batch
return
if end is not None and issue.updated_at > end:
continue
if issue.pull_request is not None:
# PRs are handled separately
continue
doc_batch.append(_convert_issue_to_document(issue))
yield doc_batch
def load_from_state(self) -> GenerateDocumentsOutput:
return self._fetch_from_github()
@@ -262,29 +234,19 @@ class GithubConnector(LoadConnector, PollConnector):
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub credentials not loaded.")
if not self.repo_owner:
if not self.repo_owner or not self.repo_name:
raise ConnectorValidationError(
"Invalid connector settings: 'repo_owner' must be provided."
"Invalid connector settings: 'repo_owner' and 'repo_name' must be provided."
)
try:
if self.repo_name:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
org = self.github_client.get_organization(self.repo_owner)
org.get_repos().totalCount # Just check if we can access repos
except GithubException:
# If not an org, try as a user
user = self.github_client.get_user(self.repo_owner)
user.get_repos().totalCount # Just check if we can access repos
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
except RateLimitExceededException:
raise UnexpectedValidationError(
raise UnexpectedError(
"Validation failed due to GitHub rate-limits being exceeded. Please try again later."
)
@@ -298,14 +260,9 @@ class GithubConnector(LoadConnector, PollConnector):
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
if self.repo_name:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
)
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"

View File

@@ -305,7 +305,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
userId=user_email,
fields=THREAD_FIELDS,
id=thread["id"],
continue_on_404_or_403=True,
)
# full_threads is an iterator containing a single thread
# so we need to convert it to a list and grab the first element
@@ -337,7 +336,6 @@ class GmailConnector(LoadConnector, PollConnector, SlimConnector):
userId=user_email,
fields=THREAD_LIST_FIELDS,
q=query,
continue_on_404_or_403=True,
):
doc_batch.append(
SlimDocument(

View File

@@ -13,9 +13,6 @@ from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.google_drive.doc_conversion import build_slim_document
from onyx.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
@@ -45,7 +42,6 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -141,7 +137,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
raise ConnectorValidationError(
raise ValueError(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
@@ -155,7 +151,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
and not my_drive_emails
and not shared_drive_urls
):
raise ConnectorValidationError(
raise ValueError(
"Nothing to index. Please specify at least one of the following: "
"include_shared_drives, include_my_drives, include_files_shared_with_me, "
"shared_folder_urls, or my_drive_emails"
@@ -613,50 +609,3 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if MISSING_SCOPES_ERROR_STR in str(e):
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
def validate_connector_settings(self) -> None:
if self._creds is None:
raise ConnectorMissingCredentialError(
"Google Drive credentials not loaded."
)
if self._primary_admin_email is None:
raise ConnectorValidationError(
"Primary admin email not found in credentials. "
"Ensure DB_CREDENTIALS_PRIMARY_ADMIN_KEY is set."
)
try:
drive_service = get_drive_service(self._creds, self._primary_admin_email)
drive_service.files().list(pageSize=1, fields="files(id)").execute()
if isinstance(self._creds, ServiceAccountCredentials):
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
status_code = e.resp.status if e.resp else None
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Google Drive credentials (401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Google Drive app lacks required permissions (403). "
"Please ensure the necessary scopes are granted and Drive "
"apps are enabled."
)
else:
raise ConnectorValidationError(
f"Unexpected Google Drive error (status={status_code}): {e}"
)
except Exception as e:
# Check for scope-related hints from the error message
if MISSING_SCOPES_ERROR_STR in str(e):
raise InsufficientPermissionsError(
"Google Drive credentials are missing required scopes. "
f"{ONYX_SCOPE_INSTRUCTIONS}"
)
raise ConnectorValidationError(
f"Unexpected error during Google Drive validation: {e}"
)

View File

@@ -1,9 +1,7 @@
import io
from datetime import datetime
from datetime import timezone
from tempfile import NamedTemporaryFile
import openpyxl # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
@@ -45,15 +43,12 @@ def _extract_sections_basic(
) -> list[Section]:
mime_type = file["mimeType"]
link = file["webViewLink"]
supported_file_types = set(item.value for item in GDriveMimeType)
if mime_type not in supported_file_types:
if mime_type not in set(item.value for item in GDriveMimeType):
# Unsupported file types can still have a title; finding it this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
# ---------------------------
# Google Sheets extraction
if mime_type == GDriveMimeType.SPREADSHEET.value:
try:
sheets_service = build(
@@ -114,53 +109,7 @@ def _extract_sections_basic(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
# ---------------------------
# Microsoft Excel (.xlsx or .xls) extraction branch
elif mime_type in [
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
]:
try:
response = service.files().get_media(fileId=file["id"]).execute()
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
tmp.write(response)
tmp_path = tmp.name
section_separator = "\n\n"
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
# Work similarly to the xlsx_to_text function used for file connector
# but returns Sections instead of a string
sections = [
Section(
link=link,
text=(
f"Sheet: {sheet.title}\n\n"
+ section_separator.join(
",".join(map(str, row))
for row in sheet.iter_rows(
min_row=1, values_only=True
)
if row
)
),
)
for sheet in workbook.worksheets
]
return sections
except Exception as e:
logger.warning(
f"Error extracting data from Excel file '{file['name']}': {e}"
)
return [
Section(link=link, text="Error extracting data from Excel file")
]
# ---------------------------
# Export for Google Docs, PPT, and fallback for spreadsheets
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
@@ -179,8 +128,6 @@ def _extract_sections_basic(
)
return [Section(link=link, text=text)]
# ---------------------------
# Plain text and Markdown files
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
@@ -194,8 +141,6 @@ def _extract_sections_basic(
.decode("utf-8"),
)
]
# ---------------------------
# Word, PowerPoint, PDF files
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
@@ -225,11 +170,7 @@ def _extract_sections_basic(
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
]
# Catch-all case, should not happen since there should be specific handling
# for each of the supported file types
error_message = f"Unsupported file type: {mime_type}"
logger.error(error_message)
raise ValueError(error_message)
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
except Exception:
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]

View File

@@ -5,10 +5,6 @@ from typing import Any
class GDriveMimeType(str, Enum):
DOC = "application/vnd.google-apps.document"
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
SPREADSHEET_OPEN_FORMAT = (
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
)
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
PDF = "application/pdf"
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
PPT = "application/vnd.google-apps.presentation"

View File

@@ -87,18 +87,16 @@ class HubSpotConnector(LoadConnector, PollConnector):
contact = api_client.crm.contacts.basic_api.get_by_id(
contact_id=contact.id
)
email = contact.properties.get("email")
if email is not None:
associated_emails.append(email)
associated_emails.append(contact.properties["email"])
if notes:
for note in notes.results:
note = api_client.crm.objects.notes.basic_api.get_by_id(
note_id=note.id, properties=["content", "hs_body_preview"]
)
preview = note.properties.get("hs_body_preview")
if preview is not None:
associated_notes.append(preview)
if note.properties["hs_body_preview"] is None:
continue
associated_notes.append(note.properties["hs_body_preview"])
associated_emails_str = " ,".join(associated_emails)
associated_notes_str = " ".join(associated_notes)

View File

@@ -1,10 +1,7 @@
import abc
from collections.abc import Generator
from collections.abc import Iterator
from types import TracebackType
from typing import Any
from typing import Generic
from typing import TypeVar
from pydantic import BaseModel
@@ -114,69 +111,6 @@ class OAuthConnector(BaseConnector):
raise NotImplementedError
T = TypeVar("T", bound="CredentialsProviderInterface")
class CredentialsProviderInterface(abc.ABC, Generic[T]):
@abc.abstractmethod
def __enter__(self) -> T:
raise NotImplementedError
@abc.abstractmethod
def __exit__(
self,
exc_type: type[BaseException] | None,
exc_value: BaseException | None,
traceback: TracebackType | None,
) -> None:
raise NotImplementedError
@abc.abstractmethod
def get_tenant_id(self) -> str | None:
raise NotImplementedError
@abc.abstractmethod
def get_provider_key(self) -> str:
"""a unique key that the connector can use to lock around a credential
that might be used simultaneously.
Will typically be the credential id, but can also just be something random
in cases when there is nothing to lock (aka static credentials)
"""
raise NotImplementedError
@abc.abstractmethod
def get_credentials(self) -> dict[str, Any]:
raise NotImplementedError
@abc.abstractmethod
def set_credentials(self, credential_json: dict[str, Any]) -> None:
raise NotImplementedError
@abc.abstractmethod
def is_dynamic(self) -> bool:
"""If dynamic, the credentials may change during usage ... maening the client
needs to use the locking features of the credentials provider to operate
correctly.
If static, the client can simply reference the credentials once and use them
through the entire indexing run.
"""
raise NotImplementedError
class CredentialsConnector(BaseConnector):
"""Implement this if the connector needs to be able to read and write credentials
on the fly. Typically used with shared credentials/tokens that might be renewed
at any time."""
@abc.abstractmethod
def set_credentials_provider(
self, credentials_provider: CredentialsProviderInterface
) -> None:
raise NotImplementedError
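
A minimal sketch of a static provider built against the interface above: credentials never change, so the locking hooks are no-ops. The class name `StaticCredentialsProvider` and the stored dict are illustrative assumptions, not taken from this diff:

```
# Hypothetical static provider: credentials never change, so locking is a no-op.
from types import TracebackType
from typing import Any


class StaticCredentialsProvider(
    CredentialsProviderInterface["StaticCredentialsProvider"]
):
    def __init__(
        self, credential_json: dict[str, Any], tenant_id: str | None = None
    ) -> None:
        self._credential_json = credential_json
        self._tenant_id = tenant_id

    def __enter__(self) -> "StaticCredentialsProvider":
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        return None

    def get_tenant_id(self) -> str | None:
        return self._tenant_id

    def get_provider_key(self) -> str:
        # Nothing to lock for static credentials; a random key would also do.
        return "static"

    def get_credentials(self) -> dict[str, Any]:
        return self._credential_json

    def set_credentials(self, credential_json: dict[str, Any]) -> None:
        self._credential_json = credential_json

    def is_dynamic(self) -> bool:
        return False
```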
# Event driven
class EventConnector(BaseConnector):
@abc.abstractmethod
@@ -212,3 +146,46 @@ class CheckpointConnector(BaseConnector):
```
"""
raise NotImplementedError
class ConnectorValidationError(Exception):
"""General exception for connector validation errors."""
def __init__(self, message: str):
self.message = message
super().__init__(self.message)
class UnexpectedError(Exception):
"""Raised when an unexpected error occurs during connector validation.
Unexpected errors don't necessarily mean the credential is invalid,
but rather that there was an error during the validation process
or we encountered a currently unhandled error case.
"""
def __init__(self, message: str = "Unexpected error during connector validation"):
super().__init__(message)
class CredentialInvalidError(ConnectorValidationError):
"""Raised when a connector's credential is invalid."""
def __init__(self, message: str = "Credential is invalid"):
super().__init__(message)
class CredentialExpiredError(ConnectorValidationError):
"""Raised when a connector's credential is expired."""
def __init__(self, message: str = "Credential has expired"):
super().__init__(message)
class InsufficientPermissionsError(ConnectorValidationError):
"""Raised when the credential does not have sufficient API permissions."""
def __init__(
self, message: str = "Insufficient permissions for the requested operation"
):
super().__init__(message)
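
As a hedged sketch of how a connector's validation code might map HTTP statuses onto this hierarchy; the helper name and messages are illustrative assumptions, not taken from any specific connector in this diff:

```
# Illustrative mapping from an HTTP status code to the exceptions defined above.
def raise_for_validation_status(status_code: int) -> None:
    if status_code == 401:
        raise CredentialExpiredError(
            "Credential appears to be expired or invalid (HTTP 401)."
        )
    if status_code == 403:
        raise InsufficientPermissionsError(
            "Token lacks sufficient permissions for this resource (HTTP 403)."
        )
    if status_code == 404:
        raise ConnectorValidationError("Configured resource not found (HTTP 404).")
    raise UnexpectedError(f"Unexpected error during validation (HTTP {status_code}).")
```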

View File

@@ -16,11 +16,10 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rl_requests,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -671,12 +670,12 @@ class NotionConnector(LoadConnector, PollConnector):
"Please try again later."
)
else:
raise UnexpectedValidationError(
raise Exception(
f"Unexpected Notion HTTP error (status={status_code}): {http_err}"
) from http_err
except Exception as exc:
raise UnexpectedValidationError(
raise Exception(
f"Unexpected error during Notion settings validation: {exc}"
)

View File

@@ -12,11 +12,11 @@ from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
@@ -29,6 +29,7 @@ from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
from onyx.connectors.onyx_jira.utils import build_jira_client
from onyx.connectors.onyx_jira.utils import build_jira_url
from onyx.connectors.onyx_jira.utils import extract_jira_project
from onyx.connectors.onyx_jira.utils import extract_text_from_adf
from onyx.connectors.onyx_jira.utils import get_comment_strs
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -159,8 +160,7 @@ def fetch_jira_issues_batch(
class JiraConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
jira_base_url: str,
project_key: str | None = None,
jira_project_url: str,
comment_email_blacklist: list[str] | None = None,
batch_size: int = INDEX_BATCH_SIZE,
# if a ticket has one of the labels specified in this list, we will just
@@ -169,12 +169,11 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP,
) -> None:
self.batch_size = batch_size
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present
self.jira_project = project_key
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
self.jira_base, self._jira_project = extract_jira_project(jira_project_url)
self._jira_client: JIRA | None = None
self._comment_email_blacklist = comment_email_blacklist or []
self.labels_to_skip = set(labels_to_skip)
@property
def comment_email_blacklist(self) -> tuple:
@@ -189,9 +188,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
@property
def quoted_jira_project(self) -> str:
# Quote the project name to handle reserved words
if not self.jira_project:
return ""
return f'"{self.jira_project}"'
return f'"{self._jira_project}"'
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._jira_client = build_jira_client(
@@ -200,14 +197,8 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
)
return None
def _get_jql_query(self) -> str:
"""Get the JQL query based on whether a specific project is set"""
if self.jira_project:
return f"project = {self.quoted_jira_project}"
return "" # Empty string means all accessible projects
def load_from_state(self) -> GenerateDocumentsOutput:
jql = self._get_jql_query()
jql = f"project = {self.quoted_jira_project}"
document_batch = []
for doc in fetch_jira_issues_batch(
@@ -234,10 +225,11 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
"%Y-%m-%d %H:%M"
)
base_jql = self._get_jql_query()
jql = (
f"{base_jql} AND " if base_jql else ""
) + f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'"
f"project = {self.quoted_jira_project} AND "
f"updated >= '{start_date_str}' AND "
f"updated <= '{end_date_str}'"
)
document_batch = []
for doc in fetch_jira_issues_batch(
@@ -260,7 +252,7 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
jql = self._get_jql_query()
jql = f"project = {self.quoted_jira_project}"
slim_doc_batch = []
for issue in _paginate_jql_search(
@@ -287,63 +279,43 @@ class JiraConnector(LoadConnector, PollConnector, SlimConnector):
if self._jira_client is None:
raise ConnectorMissingCredentialError("Jira")
# If a specific project is set, validate it exists
if self.jira_project:
try:
self.jira_client.project(self.jira_project)
except Exception as e:
status_code = getattr(e, "status_code", None)
if not self._jira_project:
raise ConnectorValidationError(
"Invalid connector settings: 'jira_project' must be provided."
)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
)
elif status_code == 404:
raise ConnectorValidationError(
f"Jira project not found with key: {self.jira_project}"
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
try:
self.jira_client.project(self._jira_project)
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
else:
# If no project specified, validate we can access the Jira API
try:
# Try to list projects to validate access
self.jira_client.projects()
except Exception as e:
status_code = getattr(e, "status_code", None)
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions to list projects (HTTP 403)."
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
except Exception as e:
status_code = getattr(e, "status_code", None)
raise RuntimeError(f"Unexpected Jira error during validation: {e}")
if status_code == 401:
raise CredentialExpiredError(
"Jira credential appears to be expired or invalid (HTTP 401)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your Jira token does not have sufficient permissions for this project (HTTP 403)."
)
elif status_code == 404:
raise ConnectorValidationError(
f"Jira project not found with key: {self._jira_project}"
)
elif status_code == 429:
raise ConnectorValidationError(
"Validation failed due to Jira rate-limits being exceeded. Please try again later."
)
else:
raise Exception(f"Unexpected Jira error during validation: {e}")
if __name__ == "__main__":
import os
connector = JiraConnector(
jira_base_url=os.environ["JIRA_BASE_URL"],
project_key=os.environ.get("JIRA_PROJECT_KEY"),
comment_email_blacklist=[],
os.environ["JIRA_PROJECT_URL"], comment_email_blacklist=[]
)
connector.load_credentials(
{
"jira_user_email": os.environ["JIRA_USER_EMAIL"],

View File

@@ -18,10 +18,6 @@ from slack_sdk.errors import SlackApiError
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
@@ -86,14 +82,14 @@ def get_channels(
get_public: bool = True,
get_private: bool = True,
) -> list[ChannelType]:
"""Get all channels in the workspace."""
"""Get all channels in the workspace"""
channels: list[dict[str, Any]] = []
channel_types = []
if get_public:
channel_types.append("public_channel")
if get_private:
channel_types.append("private_channel")
# Try fetching both public and private channels first:
# try getting private channels as well at first
try:
channels = _collect_paginated_channels(
client=client,
@@ -101,19 +97,19 @@ def get_channels(
channel_types=channel_types,
)
except SlackApiError as e:
logger.info(
f"Unable to fetch private channels due to: {e}. Trying again without private channels."
)
logger.info(f"Unable to fetch private channels due to - {e}")
logger.info("trying again without private channels")
if get_public:
channel_types = ["public_channel"]
else:
logger.warning("No channels to fetch.")
logger.warning("No channels to fetch")
return []
channels = _collect_paginated_channels(
client=client,
exclude_archived=exclude_archived,
channel_types=channel_types,
)
return channels
@@ -670,88 +666,6 @@ class SlackConnector(SlimConnector, CheckpointConnector):
)
return checkpoint
def validate_connector_settings(self) -> None:
"""
1. Verify the bot token is valid for the workspace (via auth_test).
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists.
"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
try:
# 1) Validate connection to workspace
auth_response = self.client.auth_test()
if not auth_response.get("ok", False):
error_msg = auth_response.get(
"error", "Unknown error from Slack auth_test"
)
raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}")
# 2) Minimal test to confirm listing channels works
test_resp = self.client.conversations_list(
limit=1, types=["public_channel"]
)
if not test_resp.get("ok", False):
error_msg = test_resp.get("error", "Unknown error from Slack")
if error_msg == "invalid_auth":
raise ConnectorValidationError(
f"Invalid Slack bot token ({error_msg})."
)
elif error_msg == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({error_msg})."
)
raise UnexpectedValidationError(
f"Slack API returned a failure: {error_msg}"
)
# 3) If channels are specified, verify each is accessible
if self.channels:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,
get_public=True,
get_private=True,
)
# For quick lookups by name or ID, build a map:
accessible_channel_names = {ch["name"] for ch in accessible_channels}
accessible_channel_ids = {ch["id"] for ch in accessible_channels}
for user_channel in self.channels:
if (
user_channel not in accessible_channel_names
and user_channel not in accessible_channel_ids
):
raise ConnectorValidationError(
f"Channel '{user_channel}' not found or inaccessible in this workspace."
)
except SlackApiError as e:
slack_error = e.response.get("error", "")
if slack_error == "missing_scope":
raise InsufficientPermissionsError(
"Slack bot token lacks the necessary scope to list/access channels. "
"Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)."
)
elif slack_error == "invalid_auth":
raise CredentialExpiredError(
f"Invalid Slack bot token ({slack_error})."
)
elif slack_error == "not_authed":
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
)
raise UnexpectedValidationError(
f"Unexpected Slack error '{slack_error}' during settings validation."
)
except ConnectorValidationError as e:
raise e
except Exception as e:
raise UnexpectedValidationError(
f"Unexpected error during Slack settings validation: {e}"
)
if __name__ == "__main__":
import os

View File

@@ -72,7 +72,6 @@ def make_slack_api_rate_limited(
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse:
last_exception = None
for _ in range(max_retries):
try:
# Make the API call

View File

@@ -5,7 +5,6 @@ from typing import Any
import msal # type: ignore
from office365.graph_client import GraphClient # type: ignore
from office365.runtime.client_request_exception import ClientRequestException # type: ignore
from office365.teams.channels.channel import Channel # type: ignore
from office365.teams.chats.messages.message import ChatMessage # type: ignore
from office365.teams.team import Team # type: ignore
@@ -13,10 +12,6 @@ from office365.teams.team import Team # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -284,50 +279,6 @@ class TeamsConnector(LoadConnector, PollConnector):
end_datetime = datetime.fromtimestamp(end, timezone.utc)
return self._fetch_from_teams(start=start_datetime, end=end_datetime)
def validate_connector_settings(self) -> None:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams credentials not loaded.")
try:
# Minimal call to confirm we can retrieve Teams
found_teams = self._get_all_teams()
except ClientRequestException as e:
status_code = e.response.status_code
if status_code == 401:
raise CredentialExpiredError(
"Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
)
elif status_code == 403:
raise InsufficientPermissionsError(
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
)
raise UnexpectedValidationError(f"Unexpected error retrieving teams: {e}")
except Exception as e:
error_str = str(e).lower()
if (
"unauthorized" in error_str
or "401" in error_str
or "invalid_grant" in error_str
):
raise CredentialExpiredError(
"Invalid or expired Microsoft Teams credentials."
)
elif "forbidden" in error_str or "403" in error_str:
raise InsufficientPermissionsError(
"App lacks required permissions to read from Microsoft Teams."
)
raise ConnectorValidationError(
f"Unexpected error during Teams validation: {e}"
)
if not found_teams:
raise ConnectorValidationError(
"No Teams found for the given credentials. "
"Either there are no Teams in this tenant, or your app does not have permission to view them."
)
if __name__ == "__main__":
connector = TeamsConnector(teams=os.environ["TEAMS"].split(","))

View File

@@ -25,12 +25,12 @@ from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_TOKEN_URL
from onyx.configs.app_configs import WEB_CONNECTOR_VALIDATE_URLS
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import ConnectorValidationError
from onyx.connectors.interfaces import CredentialExpiredError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import InsufficientPermissionsError
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import UnexpectedError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import read_pdf_file
@@ -42,10 +42,6 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
# Threshold for determining when to replace vs append iframe content
IFRAME_TEXT_LENGTH_THRESHOLD = 700
# Message indicating JavaScript is disabled, which often appears when scraping fails
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
@@ -142,8 +138,7 @@ def get_internal_links(
# Account for malformed backslashes in URLs
href = href.replace("\\", "/")
# "#!" indicates the page is using a hashbang URL, which is a client-side routing technique
if should_ignore_pound and "#" in href and "#!" not in href:
if should_ignore_pound and "#" in href:
href = href.split("#")[0]
if not is_valid_url(href):
@@ -293,7 +288,6 @@ class WebConnector(LoadConnector):
and converts them into documents"""
visited_links: set[str] = set()
to_visit: list[str] = self.to_visit_list
content_hashes = set()
if not to_visit:
raise ValueError("No URLs to visit")
@@ -308,30 +302,29 @@ class WebConnector(LoadConnector):
playwright, context = start_playwright()
restart_playwright = False
while to_visit:
initial_url = to_visit.pop()
if initial_url in visited_links:
current_url = to_visit.pop()
if current_url in visited_links:
continue
visited_links.add(initial_url)
visited_links.add(current_url)
try:
protected_url_check(initial_url)
protected_url_check(current_url)
except Exception as e:
last_error = f"Invalid URL {initial_url} due to {e}"
last_error = f"Invalid URL {current_url} due to {e}"
logger.warning(last_error)
continue
index = len(visited_links)
logger.info(f"{index}: Visiting {initial_url}")
logger.info(f"Visiting {current_url}")
try:
check_internet_connection(initial_url)
check_internet_connection(current_url)
if restart_playwright:
playwright, context = start_playwright()
restart_playwright = False
if initial_url.split(".")[-1] == "pdf":
if current_url.split(".")[-1] == "pdf":
# PDF files are not checked for links
response = requests.get(initial_url)
response = requests.get(current_url)
page_text, metadata = read_pdf_file(
file=io.BytesIO(response.content)
)
@@ -339,10 +332,10 @@ class WebConnector(LoadConnector):
doc_batch.append(
Document(
id=initial_url,
sections=[Section(link=initial_url, text=page_text)],
id=current_url,
sections=[Section(link=current_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=initial_url.split("/")[-1],
semantic_identifier=current_url.split("/")[-1],
metadata=metadata,
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
@@ -354,29 +347,21 @@ class WebConnector(LoadConnector):
continue
page = context.new_page()
# Can't use wait_until="networkidle" because it interferes with the scrolling behavior
page_response = page.goto(
initial_url,
timeout=30000, # 30 seconds
)
page_response = page.goto(current_url)
last_modified = (
page_response.header_value("Last-Modified")
if page_response
else None
)
final_url = page.url
if final_url != initial_url:
protected_url_check(final_url)
initial_url = final_url
if initial_url in visited_links:
logger.info(
f"{index}: {initial_url} redirected to {final_url} - already indexed"
)
final_page = page.url
if final_page != current_url:
logger.info(f"Redirected to {final_page}")
protected_url_check(final_page)
current_url = final_page
if current_url in visited_links:
logger.info("Redirected page already indexed")
continue
logger.info(f"{index}: {initial_url} redirected to {final_url}")
visited_links.add(initial_url)
visited_links.add(current_url)
if self.scroll_before_scraping:
scroll_attempts = 0
@@ -394,58 +379,26 @@ class WebConnector(LoadConnector):
soup = BeautifulSoup(content, "html.parser")
if self.recursive:
internal_links = get_internal_links(base_url, initial_url, soup)
internal_links = get_internal_links(base_url, current_url, soup)
for link in internal_links:
if link not in visited_links:
to_visit.append(link)
if page_response and str(page_response.status)[0] in ("4", "5"):
last_error = f"Skipped indexing {initial_url} due to HTTP {page_response.status} response"
last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response"
logger.info(last_error)
continue
parsed_html = web_html_cleanup(soup, self.mintlify_cleanup)
"""For websites containing iframes that need to be scraped,
the code below can extract text from within these iframes.
"""
logger.debug(
f"{index}: Length of cleaned text {len(parsed_html.cleaned_text)}"
)
if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text:
iframe_count = page.frame_locator("iframe").locator("html").count()
if iframe_count > 0:
iframe_texts = (
page.frame_locator("iframe")
.locator("html")
.all_inner_texts()
)
document_text = "\n".join(iframe_texts)
""" 700 is the threshold value for the length of the text extracted
from the iframe based on the issue faced """
if len(parsed_html.cleaned_text) < IFRAME_TEXT_LENGTH_THRESHOLD:
parsed_html.cleaned_text = document_text
else:
parsed_html.cleaned_text += "\n" + document_text
# Sometimes pages with #! will serve duplicate content
# There are also just other ways this can happen
hashed_text = hash((parsed_html.title, parsed_html.cleaned_text))
if hashed_text in content_hashes:
logger.info(
f"{index}: Skipping duplicate title + content for {initial_url}"
)
continue
content_hashes.add(hashed_text)
doc_batch.append(
Document(
id=initial_url,
id=current_url,
sections=[
Section(link=initial_url, text=parsed_html.cleaned_text)
Section(link=current_url, text=parsed_html.cleaned_text)
],
source=DocumentSource.WEB,
semantic_identifier=parsed_html.title or initial_url,
semantic_identifier=parsed_html.title or current_url,
metadata={},
doc_updated_at=_get_datetime_from_last_modified_header(
last_modified
@@ -457,7 +410,7 @@ class WebConnector(LoadConnector):
page.close()
except Exception as e:
last_error = f"Failed to fetch '{initial_url}': {e}"
last_error = f"Failed to fetch '{current_url}': {e}"
logger.exception(last_error)
playwright.stop()
restart_playwright = True
@@ -487,10 +440,7 @@ class WebConnector(LoadConnector):
"No URL configured. Please provide at least one valid URL."
)
if (
self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SITEMAP.value
or self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value
):
if self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SITEMAP.value:
return None
# We'll just test the first URL for connectivity and correctness
@@ -528,9 +478,7 @@ class WebConnector(LoadConnector):
)
else:
# Could be a 5xx or another error, treat as unexpected
raise UnexpectedValidationError(
f"Unexpected error validating '{test_url}': {e}"
)
raise UnexpectedError(f"Unexpected error validating '{test_url}': {e}")
if __name__ == "__main__":

View File

@@ -76,10 +76,6 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
provider_type=search_settings.provider_type,
index_name=search_settings.index_name,
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
# Whether switching to this model requires re-indexing
background_reindex_enabled=search_settings.background_reindex_enabled,
# Reranking Details
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,

View File

@@ -16,6 +16,7 @@ from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.server.api_key.models import APIKeyArgs
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -72,7 +73,7 @@ def insert_api_key(
# Get tenant_id from context var (will be default schema for single tenant)
tenant_id = get_current_tenant_id()
api_key = generate_api_key(tenant_id)
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
api_key_user_id = uuid.uuid4()
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER

View File

@@ -168,7 +168,7 @@ def get_chat_sessions_by_user(
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
stmt = stmt.order_by(desc(ChatSession.time_updated))
stmt = stmt.order_by(desc(ChatSession.time_created))
if deleted is not None:
stmt = stmt.where(ChatSession.deleted == deleted)
@@ -962,7 +962,6 @@ def translate_db_message_to_chat_message_detail(
chat_message.sub_questions
),
refined_answer_improvement=chat_message.refined_answer_improvement,
is_agentic=chat_message.is_agentic,
error=chat_message.error,
)

View File

@@ -1,111 +0,0 @@
from typing import List
from typing import Optional
from typing import Tuple
from uuid import UUID
from sqlalchemy import column
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
def search_chat_sessions(
user_id: UUID | None,
db_session: Session,
query: Optional[str] = None,
page: int = 1,
page_size: int = 10,
include_deleted: bool = False,
include_onyxbot_flows: bool = False,
) -> Tuple[List[ChatSession], bool]:
"""
Fast full-text search on ChatSession + ChatMessage using tsvectors.
If no query is provided, returns the most recent chat sessions.
Otherwise, searches both chat messages and session descriptions.
Returns a tuple of (sessions, has_more) where has_more indicates if
there are additional results beyond the requested page.
"""
offset_val = (page - 1) * page_size
# If no query, just return the most recent sessions
if not query or not query.strip():
stmt = (
select(ChatSession)
.order_by(desc(ChatSession.time_created))
.offset(offset_val)
.limit(page_size + 1)
)
if user_id is not None:
stmt = stmt.where(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
stmt = stmt.where(ChatSession.deleted.is_(False))
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
sessions = result.scalars().all()
has_more = len(sessions) > page_size
if has_more:
sessions = sessions[:page_size]
return list(sessions), has_more
# Otherwise, proceed with full-text search
query = query.strip()
base_conditions = []
if user_id is not None:
base_conditions.append(ChatSession.user_id == user_id)
if not include_onyxbot_flows:
base_conditions.append(ChatSession.onyxbot_flow.is_(False))
if not include_deleted:
base_conditions.append(ChatSession.deleted.is_(False))
message_tsv: ColumnClause = column("message_tsv")
description_tsv: ColumnClause = column("description_tsv")
ts_query = func.plainto_tsquery("english", query)
description_session_ids = (
select(ChatSession.id)
.where(*base_conditions)
.where(description_tsv.op("@@")(ts_query))
)
message_session_ids = (
select(ChatMessage.chat_session_id)
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
.where(*base_conditions)
.where(message_tsv.op("@@")(ts_query))
)
combined_ids = description_session_ids.union(message_session_ids).alias(
"combined_ids"
)
final_stmt = (
select(ChatSession)
.join(combined_ids, ChatSession.id == combined_ids.c.id)
.order_by(desc(ChatSession.time_created))
.distinct()
.offset(offset_val)
.limit(page_size + 1)
.options(joinedload(ChatSession.persona))
)
session_objs = db_session.execute(final_stmt).scalars().all()
has_more = len(session_objs) > page_size
if has_more:
session_objs = session_objs[:page_size]
return list(session_objs), has_more
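
A hedged usage sketch of the function above; `collect_matching_sessions` and the page size are illustrative assumptions:

```
# Illustrative pagination loop over the full-text search results.
from uuid import UUID


def collect_matching_sessions(
    db_session: Session, user_id: UUID | None, query: str
) -> list[ChatSession]:
    results: list[ChatSession] = []
    page = 1
    while True:
        sessions, has_more = search_chat_sessions(
            user_id=user_id,
            db_session=db_session,
            query=query,
            page=page,
            page_size=10,
        )
        results.extend(sessions)
        if not has_more:
            return results
        page += 1
```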

View File

@@ -1,5 +1,4 @@
from datetime import datetime
from typing import TypeVarTuple
from fastapi import HTTPException
from sqlalchemy import delete
@@ -9,18 +8,15 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import IndexModelStatus
@@ -35,12 +31,10 @@ from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
R = TypeVarTuple("R")
def _add_user_filters(
stmt: Select[tuple[*R]], user: User | None, get_editable: bool = True
) -> Select[tuple[*R]]:
stmt: Select, user: User | None, get_editable: bool = True
) -> Select:
# If user is None and auth is disabled, assume the user is an admin
if (user is None and DISABLE_AUTH) or (user and user.role == UserRole.ADMIN):
return stmt
@@ -104,52 +98,17 @@ def get_connector_credential_pairs_for_user(
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
if eager_load_user:
assert (
eager_load_credential
), "eager_load_credential must be True if eager_load_user is True"
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
if eager_load_user:
load_opts = load_opts.joinedload(Credential.user)
stmt = stmt.options(load_opts)
stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))
stmt = _add_user_filters(stmt, user, get_editable)
if ids:
stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))
return list(db_session.scalars(stmt).unique().all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_connector_credential_pairs_for_user_parallel(
user: User | None,
get_editable: bool = True,
ids: list[int] | None = None,
eager_load_connector: bool = False,
eager_load_credential: bool = False,
eager_load_user: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_connector_credential_pairs_for_user(
db_session,
user,
get_editable,
ids,
eager_load_connector,
eager_load_credential,
eager_load_user,
)
return list(db_session.scalars(stmt).all())
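
A sketch of why the `_parallel` variant exists: it opens its own session, so it can safely be dispatched to a worker thread. The plain `ThreadPoolExecutor` here is a stand-in for the project's thread-level parallelism utils, and the argument values are placeholders:

```
# Illustrative fan-out; each callable manages its own DB session internally.
from concurrent.futures import ThreadPoolExecutor

with ThreadPoolExecutor(max_workers=4) as pool:
    future = pool.submit(
        get_connector_credential_pairs_for_user_parallel,
        None,  # user (None with auth disabled is treated as admin)
        True,  # get_editable
        None,  # ids
        True,  # eager_load_connector: relationships must be eagerly loaded
    )
    cc_pairs = future.result()
```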
def get_connector_credential_pairs(
@@ -192,16 +151,6 @@ def get_cc_pair_groups_for_ids(
return list(db_session.scalars(stmt).all())
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_cc_pair_groups_for_ids_parallel(
cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
with get_session_context_manager() as db_session:
return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)
def get_connector_credential_pair_for_user(
db_session: Session,
connector_id: int,
@@ -452,8 +401,8 @@ def add_credential_to_connector(
# If we are in the seeding flow, we shouldn't need to check if the credential belongs to the user
if seeding_flow:
credential = fetch_credential_by_id(
credential_id=credential_id,
db_session=db_session,
credential_id=credential_id,
)
else:
credential = fetch_credential_by_id_for_user(

View File

@@ -169,8 +169,8 @@ def fetch_credential_by_id_for_user(
def fetch_credential_by_id(
credential_id: int,
db_session: Session,
credential_id: int,
) -> Credential | None:
stmt = select(Credential).distinct()
stmt = stmt.where(Credential.id == credential_id)
@@ -360,13 +360,18 @@ def backend_update_credential_json(
db_session.commit()
def _delete_credential_internal(
credential: Credential,
def delete_credential(
credential_id: int,
user: User | None,
db_session: Session,
force: bool = False,
) -> None:
"""Internal utility function to handle the actual deletion of a credential"""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
raise ValueError(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
)
associated_connectors = (
db_session.query(ConnectorCredentialPair)
.filter(ConnectorCredentialPair.credential_id == credential_id)
@@ -411,43 +416,14 @@ def _delete_credential_internal(
db_session.commit()
def delete_credential_for_user(
credential_id: int,
user: User,
db_session: Session,
force: bool = False,
) -> None:
"""Delete a credential that belongs to a specific user"""
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
if credential is None:
raise ValueError(
f"Credential by provided id {credential_id} does not exist or does not belong to user"
)
_delete_credential_internal(credential, credential_id, db_session, force)
def delete_credential(
credential_id: int,
db_session: Session,
force: bool = False,
) -> None:
"""Delete a credential regardless of ownership (admin function)"""
credential = fetch_credential_by_id(credential_id, db_session)
if credential is None:
raise ValueError(f"Credential by provided id {credential_id} does not exist")
_delete_credential_internal(credential, credential_id, db_session, force)
def create_initial_public_credential(db_session: Session) -> None:
error_msg = (
"DB is not in a valid initial state."
"There must exist an empty public credential for data connectors that do not require additional Auth."
)
first_credential = fetch_credential_by_id(
credential_id=PUBLIC_CREDENTIAL_ID,
db_session=db_session,
credential_id=PUBLIC_CREDENTIAL_ID,
)
if first_credential is not None:

View File

@@ -24,7 +24,6 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.feedback import delete_document_feedback_for_documents__no_commit
@@ -230,12 +229,12 @@ def get_document_connector_counts(
def get_document_counts_for_cc_pairs(
db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier]
db_session: Session, cc_pair_identifiers: list[ConnectorCredentialPairIdentifier]
) -> Sequence[tuple[int, int, int]]:
"""Returns a sequence of tuples of (connector_id, credential_id, document count)"""
# Prepare a list of (connector_id, credential_id) tuples
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs]
cc_ids = [(x.connector_id, x.credential_id) for x in cc_pair_identifiers]
stmt = (
select(
@@ -261,16 +260,6 @@ def get_document_counts_for_cc_pairs(
return db_session.execute(stmt).all() # type: ignore
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_document_counts_for_cc_pairs_parallel(
cc_pairs: list[ConnectorCredentialPairIdentifier],
) -> Sequence[tuple[int, int, int]]:
with get_session_context_manager() as db_session:
return get_document_counts_for_cc_pairs(db_session, cc_pairs)
def get_access_info_for_document(
db_session: Session,
document_id: str,

View File

@@ -218,7 +218,6 @@ class SqlEngine:
final_engine_kwargs.update(engine_kwargs)
logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
# echo=True here for inspecting all emitted db queries
engine = create_engine(connection_string, **final_engine_kwargs)
if USE_IAM_AUTH:
@@ -258,11 +257,11 @@ class SqlEngine:
cls._engine = None
def get_all_tenant_ids() -> list[str]:
def get_all_tenant_ids() -> list[str] | list[None]:
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
if not MULTI_TENANT:
return [POSTGRES_DEFAULT_SCHEMA]
return [None]
with get_session_with_shared_schema() as session:
result = session.execute(
@@ -417,7 +416,7 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]:
"""
Generate a database session for a specific tenant.
"""

View File

@@ -63,9 +63,6 @@ class IndexModelStatus(str, PyEnum):
PRESENT = "PRESENT"
FUTURE = "FUTURE"
def is_current(self) -> bool:
return self == IndexModelStatus.PRESENT
class ChatSessionSharedStatus(str, PyEnum):
PUBLIC = "public"
@@ -86,11 +83,3 @@ class AccessType(str, PyEnum):
PUBLIC = "public"
PRIVATE = "private"
SYNC = "sync"
class EmbeddingPrecision(str, PyEnum):
# matches vespa tensor type
# only support float / bfloat16 for now, since there's not a
# good reason to specify anything else
BFLOAT16 = "bfloat16"
FLOAT = "float"

View File

@@ -2,7 +2,6 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import TypeVarTuple
from sqlalchemy import and_
from sqlalchemy import delete
@@ -10,13 +9,9 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import Select
from onyx.connectors.models import ConnectorFailure
from onyx.db.engine import get_session_context_manager
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexAttemptError
from onyx.db.models import IndexingStatus
@@ -373,33 +368,19 @@ def get_latest_index_attempts_by_status(
return db_session.execute(stmt).scalars().all()
T = TypeVarTuple("T")
def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]:
return stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
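
The `TypeVarTuple` annotation on `_add_only_finished_clause` lets the helper preserve the exact row type of whatever `Select` it receives, whether it selects full `IndexAttempt` rows or individual columns. A compact illustration, requiring Python 3.11+ like the code above:

```
# Illustrative: the variadic annotation preserves the Select's row type.
stmt = select(IndexAttempt)  # Select[tuple[IndexAttempt]]
finished_stmt = _add_only_finished_clause(stmt)  # still Select[tuple[IndexAttempt]]

ids_stmt = select(IndexAttempt.connector_credential_pair_id, IndexAttempt.id)
finished_ids_stmt = _add_only_finished_clause(ids_stmt)  # Select[tuple[int, int]]
```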
def get_latest_index_attempts(
secondary_index: bool,
db_session: Session,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
ids_stmt = select(
IndexAttempt.connector_credential_pair_id,
func.max(IndexAttempt.id).label("max_id"),
).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
ids_stmt = ids_stmt.where(SearchSettings.status == status)
if only_finished:
ids_stmt = _add_only_finished_clause(ids_stmt)
if secondary_index:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.FUTURE)
else:
ids_stmt = ids_stmt.where(SearchSettings.status == IndexModelStatus.PRESENT)
ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id)
ids_subquery = ids_stmt.subquery()
@@ -414,53 +395,7 @@ def get_latest_index_attempts(
.where(IndexAttempt.id == ids_subquery.c.max_id)
)
if only_finished:
stmt = _add_only_finished_clause(stmt)
if eager_load_cc_pair:
stmt = stmt.options(
joinedload(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
return db_session.execute(stmt).scalars().unique().all()
# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_latest_index_attempts_parallel(
secondary_index: bool,
eager_load_cc_pair: bool = False,
only_finished: bool = False,
) -> Sequence[IndexAttempt]:
with get_session_context_manager() as db_session:
return get_latest_index_attempts(
secondary_index,
db_session,
eager_load_cc_pair,
only_finished,
)
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
)
if only_finished:
stmt = _add_only_finished_clause(stmt)
status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT
stmt = stmt.join(SearchSettings).where(SearchSettings.status == status)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
return db_session.execute(stmt).scalars().all()
def count_index_attempts_for_connector(
@@ -518,12 +453,37 @@ def get_paginated_index_attempts_for_cc_pair_id(
# Apply pagination
stmt = stmt.offset(page * page_size).limit(page_size)
stmt = stmt.options(
contains_eager(IndexAttempt.connector_credential_pair),
joinedload(IndexAttempt.error_rows),
)
return list(db_session.execute(stmt).scalars().unique().all())
return list(db_session.execute(stmt).scalars().all())
def get_latest_index_attempt_for_cc_pair_id(
db_session: Session,
connector_credential_pair_id: int,
secondary_index: bool,
only_finished: bool = True,
) -> IndexAttempt | None:
stmt = select(IndexAttempt)
stmt = stmt.where(
IndexAttempt.connector_credential_pair_id == connector_credential_pair_id,
)
if only_finished:
stmt = stmt.where(
IndexAttempt.status.not_in(
[IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS]
),
)
if secondary_index:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.FUTURE
)
else:
stmt = stmt.join(SearchSettings).where(
SearchSettings.status == IndexModelStatus.PRESENT
)
stmt = stmt.order_by(desc(IndexAttempt.time_created))
stmt = stmt.limit(1)
return db_session.execute(stmt).scalar_one_or_none()
def get_index_attempts_for_cc_pair(

View File

@@ -7,7 +7,6 @@ from typing import Optional
from uuid import uuid4
from pydantic import BaseModel
from sqlalchemy.orm import validates
from typing_extensions import TypedDict # noreorder
from uuid import UUID
@@ -26,7 +25,6 @@ from sqlalchemy import ForeignKey
from sqlalchemy import func
from sqlalchemy import Index
from sqlalchemy import Integer
from sqlalchemy import Sequence
from sqlalchemy import String
from sqlalchemy import Text
@@ -46,13 +44,7 @@ from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.db.enums import (
AccessType,
EmbeddingPrecision,
IndexingMode,
SyncType,
SyncStatus,
)
from onyx.db.enums import AccessType, IndexingMode, SyncType, SyncStatus
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.constants import TokenRateLimitScope
@@ -213,10 +205,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
)
@validates("email")
def validate_email(self, key: str, value: str) -> str:
return value.lower() if value else value
@property
def password_configured(self) -> bool:
"""
@@ -722,23 +710,6 @@ class SearchSettings(Base):
ForeignKey("embedding_provider.provider_type"), nullable=True
)
# Whether switching to this model should re-index all connectors in the background
# if no re-index is needed, will be ignored. Only used during the switch-over process.
background_reindex_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
# allows for quantization -> less memory usage for a small performance hit
embedding_precision: Mapped[EmbeddingPrecision] = mapped_column(
Enum(EmbeddingPrecision, native_enum=False)
)
# can be used to reduce dimensionality of vectors and save memory with
# a small performance hit. More details in the `Reducing embedding dimensions`
# section here:
# https://platform.openai.com/docs/guides/embeddings#embedding-models
# If not specified, will just use the model_dim without any reduction.
# NOTE: this is only currently available for OpenAI models
reduced_dimension: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Mini and Large Chunks (large chunk also checks for model max context)
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
@@ -820,12 +791,6 @@ class SearchSettings(Base):
self.multipass_indexing, self.model_name, self.provider_type
)
@property
def final_embedding_dim(self) -> int:
if self.reduced_dimension:
return self.reduced_dimension
return self.model_dim
@staticmethod
def can_use_large_chunks(
multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
@@ -1790,7 +1755,6 @@ class ChannelConfig(TypedDict):
channel_name: str | None # None for default channel config
respond_tag_only: NotRequired[bool] # defaults to False
respond_to_bots: NotRequired[bool] # defaults to False
is_ephemeral: NotRequired[bool] # defaults to False
respond_member_group_list: NotRequired[list[str]]
answer_filters: NotRequired[list[AllowedAnswerFilters]]
# If None then no follow up
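
Since ChannelConfig keys marked NotRequired may be absent entirely, consumers have to supply the documented defaults at read time. A minimal sketch of that pattern, assuming dict-style access (the trimmed-down TypedDict below is illustrative):

from typing_extensions import NotRequired, TypedDict

class _ChannelConfig(TypedDict):
    channel_name: str | None
    respond_tag_only: NotRequired[bool]  # defaults to False when absent

def get_respond_tag_only(config: _ChannelConfig) -> bool:
    # .get() supplies the documented default for a NotRequired key.
    return config.get("respond_tag_only", False)

assert get_respond_tag_only({"channel_name": None}) is False
assert get_respond_tag_only({"channel_name": "#help", "respond_tag_only": True}) is True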
@@ -2305,10 +2269,6 @@ class UserTenantMapping(Base):
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
@validates("email")
def validate_email(self, key: str, value: str) -> str:
return value.lower() if value else value
# This is a mapping from tenant IDs to anonymous user paths
class TenantAnonymousUserPath(Base):

View File

@@ -100,14 +100,9 @@ def _add_user_filters(
.correlate(Persona)
)
else:
# Group the public persona conditions
public_condition = (Persona.is_public == True) & ( # noqa: E712
Persona.is_visible == True # noqa: E712
)
where_clause |= public_condition
where_clause |= Persona.is_public == True # noqa: E712
where_clause &= Persona.is_visible == True # noqa: E712
where_clause |= Persona__User.user_id == user.id
where_clause |= Persona.user_id == user.id
return stmt.where(where_clause)
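
The two formulations in this hunk are not equivalent: applying |= and then &= folds the visibility check over every condition accumulated so far, while the grouped form restricts it to the public branch. A tiny boolean sketch of the difference (values are illustrative):

# A persona shared with the user (prev=True) but currently hidden.
prev, is_public, is_visible, owns_it = True, False, False, False

sequential = ((prev or is_public) and is_visible) or owns_it  # |= then &= then |=
grouped = prev or (is_public and is_visible) or owns_it       # grouped public check

assert sequential is False  # visibility check also filters the shared persona
assert grouped is True      # visibility only gates the public branch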
@@ -209,21 +204,13 @@ def create_update_persona(
if not all_prompt_ids:
raise ValueError("No prompt IDs provided")
is_default_persona: bool | None = create_persona_request.is_default_persona
# Default persona validation
if create_persona_request.is_default_persona:
if not create_persona_request.is_public:
raise ValueError("Cannot make a default persona non public")
if user:
# Curators can edit default personas, but not make them
if (
user.role == UserRole.CURATOR
or user.role == UserRole.GLOBAL_CURATOR
):
is_default_persona = None
elif user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
if user and user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
persona = upsert_persona(
persona_id=persona_id,
@@ -249,7 +236,7 @@ def create_update_persona(
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=is_default_persona,
is_default_persona=create_persona_request.is_default_persona,
)
versioned_make_persona_private = fetch_versioned_implementation(
@@ -436,7 +423,7 @@ def upsert_persona(
remove_image: bool | None = None,
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool | None = None,
is_default_persona: bool = False,
label_ids: list[int] | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
@@ -531,11 +518,7 @@ def upsert_persona(
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.labels = labels or []
existing_persona.is_default_persona = (
is_default_persona
if is_default_persona is not None
else existing_persona.is_default_persona
)
existing_persona.is_default_persona = is_default_persona
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@@ -587,9 +570,7 @@ def upsert_persona(
display_priority=display_priority,
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona
if is_default_persona is not None
else False,
is_default_persona=is_default_persona,
labels=labels or [],
)
db_session.add(new_persona)

View File

@@ -14,7 +14,6 @@ from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from onyx.context.search.models import SavedSearchSettings
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import EmbeddingPrecision
from onyx.db.llm import fetch_embedding_provider
from onyx.db.models import CloudEmbeddingProvider
from onyx.db.models import IndexAttempt
@@ -60,15 +59,12 @@ def create_search_settings(
index_name=search_settings.index_name,
provider_type=search_settings.provider_type,
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
multilingual_expansion=search_settings.multilingual_expansion,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
rerank_api_key=search_settings.rerank_api_key,
num_rerank=search_settings.num_rerank,
background_reindex_enabled=search_settings.background_reindex_enabled,
)
db_session.add(embedding_model)
@@ -309,7 +305,6 @@ def get_old_default_embedding_model() -> IndexingSetting:
model_dim=(
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
),
embedding_precision=(EmbeddingPrecision.FLOAT),
normalize=(
NORMALIZE_EMBEDDINGS
if is_overridden
@@ -327,7 +322,6 @@ def get_new_default_embedding_model() -> IndexingSetting:
return IndexingSetting(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
embedding_precision=(EmbeddingPrecision.FLOAT),
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,

View File

@@ -8,12 +8,10 @@ from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_default_document_index
from onyx.key_value_store.factory import get_kv_store
from onyx.utils.logger import setup_logger
@@ -21,49 +19,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _perform_index_swap(
db_session: Session,
current_search_settings: SearchSettings,
secondary_search_settings: SearchSettings,
all_cc_pairs: list[ConnectorCredentialPair],
) -> None:
"""Swap the indices and expire the old one."""
current_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if len(all_cc_pairs) > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
# remove the old index from the vector db
document_index = get_default_document_index(secondary_search_settings, None)
document_index.ensure_indices_exist(
primary_embedding_dim=secondary_search_settings.final_embedding_dim,
primary_embedding_precision=secondary_search_settings.embedding_precision,
# just finished swap, no more secondary index
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
def check_index_swap(db_session: Session) -> SearchSettings | None:
"""Get count of cc-pairs and count of successful index_attempts for the
new model grouped by connector + credential, if it's the same, then assume
new index is done building. If so, swap the indices and expire the old one.
@@ -71,45 +27,52 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
Returns None if search settings did not change, or the old search settings if they
did change.
"""
old_search_settings = None
# The default CC-pair created for the Ingestion API is not counted here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
secondary_search_settings = get_secondary_search_settings(db_session)
search_settings = get_secondary_search_settings(db_session)
if not secondary_search_settings:
if not search_settings:
return None
# If the secondary search settings are not configured to reindex in the background,
# we can just swap over instantly
if not secondary_search_settings.background_reindex_enabled:
current_search_settings = get_current_search_settings(db_session)
_perform_index_swap(
db_session=db_session,
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
)
return current_search_settings
unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id=secondary_search_settings.id, db_session=db_session
search_settings_id=search_settings.id, db_session=db_session
)
# Index attempts are also cleaned up when the cc-pair is deleted, so the logic in this
# function is correct. The unique_cc_indexings count covers only the existing cc-pairs.
old_search_settings = None
if unique_cc_indexings > cc_pair_count:
logger.error("More unique indexings than cc pairs, should not occur")
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
# Swap indices
current_search_settings = get_current_search_settings(db_session)
_perform_index_swap(
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
)
old_search_settings = current_search_settings
update_search_settings_status(
search_settings=search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if cc_pair_count > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
old_search_settings = current_search_settings
return old_search_settings
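
The swap trigger above reduces to a small counting rule: ignore the default Ingestion API cc-pair, then swap once every remaining cc-pair has a successful index attempt against the new settings. A minimal sketch of just that condition (the function name and values are illustrative):

def _should_swap(total_cc_pairs: int, unique_successful_indexings: int) -> bool:
    # Subtract the default CC-pair created for the Ingestion API.
    cc_pair_count = max(total_cc_pairs - 1, 0)
    # unique_successful_indexings > cc_pair_count should not occur; it would
    # mean successful attempts exist for more cc-pairs than are configured.
    return cc_pair_count == 0 or cc_pair_count == unique_successful_indexings

assert _should_swap(4, 3) is True   # 3 real cc-pairs, all indexed on the new model
assert _should_swap(4, 2) is False  # one cc-pair still building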

View File

@@ -81,7 +81,7 @@ def translate_boost_count_to_multiplier(boost: int) -> float:
# Vespa's Document API.
def get_document_chunk_ids(
enriched_document_info_list: list[EnrichedDocumentIndexingInfo],
tenant_id: str,
tenant_id: str | None,
large_chunks_enabled: bool,
) -> list[UUID]:
doc_chunk_ids = []
@@ -139,7 +139,7 @@ def get_uuid_from_chunk_info(
*,
document_id: str,
chunk_id: int,
tenant_id: str,
tenant_id: str | None,
large_chunk_id: int | None = None,
) -> UUID:
"""NOTE: be VERY carefuly about changing this function. If changed without a migration,
@@ -154,7 +154,7 @@ def get_uuid_from_chunk_info(
"large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id)
)
unique_identifier_string = "_".join([doc_str, chunk_index])
if MULTI_TENANT:
if tenant_id and MULTI_TENANT:
unique_identifier_string += "_" + tenant_id
uuid_value = uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
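
The warning in the docstring above is about determinism: uuid5 is a pure function of its input string, so every chunk maps to a stable document id, and any change to the identifier layout would silently orphan previously written chunks. A simplified, runnable sketch of the scheme (it omits the large-chunk prefix handling shown above):

import uuid

def chunk_uuid(document_id: str, chunk_id: int, tenant_id: str | None) -> uuid.UUID:
    unique_identifier_string = "_".join([document_id, str(chunk_id)])
    if tenant_id:
        # Multi-tenant deployments namespace the id by tenant as well.
        unique_identifier_string += "_" + tenant_id
    return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)

# Same inputs always yield the same UUID; different tenants never collide.
assert chunk_uuid("doc-1", 0, "tenant-a") == chunk_uuid("doc-1", 0, "tenant-a")
assert chunk_uuid("doc-1", 0, "tenant-a") != chunk_uuid("doc-1", 0, "tenant-b")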

View File

@@ -6,7 +6,6 @@ from typing import Any
from onyx.access.models import DocumentAccess
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.model_server_models import Embedding
@@ -44,7 +43,7 @@ class IndexBatchParams:
doc_id_to_previous_chunk_cnt: dict[str, int | None]
doc_id_to_new_chunk_cnt: dict[str, int]
tenant_id: str
tenant_id: str | None
large_chunks_enabled: bool
@@ -146,21 +145,17 @@ class Verifiable(abc.ABC):
@abc.abstractmethod
def ensure_indices_exist(
self,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
) -> None:
"""
Verify that the document index exists and is consistent with the expectations in the code.
Parameters:
- primary_embedding_dim: Vector dimensionality for the vector similarity part of the search
- primary_embedding_precision: Precision of the vector similarity part of the search
- index_embedding_dim: Vector dimensionality for the vector similarity part of the search
- secondary_index_embedding_dim: Vector dimensionality of the secondary index being built
behind the scenes. The secondary index should only be built when switching
embedding models; therefore this dim should be different from the primary index.
- secondary_index_embedding_precision: Precision of the vector similarity part of the secondary index
"""
raise NotImplementedError
@@ -169,7 +164,6 @@ class Verifiable(abc.ABC):
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
embedding_precisions: list[EmbeddingPrecision],
) -> None:
"""
Register multitenant indices with the document index.
@@ -228,7 +222,7 @@ class Deletable(abc.ABC):
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
) -> int:
"""
@@ -255,7 +249,7 @@ class Updatable(abc.ABC):
self,
doc_id: str,
*,
tenant_id: str,
tenant_id: str | None,
chunk_count: int | None,
fields: VespaDocumentFields,
) -> int:
@@ -276,7 +270,9 @@ class Updatable(abc.ABC):
raise NotImplementedError
@abc.abstractmethod
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
def update(
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
) -> None:
"""
Updates some set of chunks. The document and fields to update are specified in the update
requests. Each update request in the list applies its changes to a list of document ids.

View File

@@ -37,7 +37,7 @@ schema DANSWER_CHUNK_NAME {
summary: dynamic
}
# Title embedding (x1)
field title_embedding type tensor<EMBEDDING_PRECISION>(x[VARIABLE_DIM]) {
field title_embedding type tensor<float>(x[VARIABLE_DIM]) {
indexing: attribute | index
attribute {
distance-metric: angular
@@ -45,7 +45,7 @@ schema DANSWER_CHUNK_NAME {
}
# Content embeddings (chunk + optional mini chunks embeddings)
# "t" and "x" are arbitrary names, not special keywords
field embeddings type tensor<EMBEDDING_PRECISION>(t{},x[VARIABLE_DIM]) {
field embeddings type tensor<float>(t{},x[VARIABLE_DIM]) {
indexing: attribute | index
attribute {
distance-metric: angular
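
DANSWER_CHUNK_NAME and VARIABLE_DIM in this schema are placeholders filled in before the schema is deployed to Vespa (this hunk hard-codes the element type back to tensor<float> rather than templating it). A hypothetical sketch of such substitution; the helper below is illustrative and not the project's actual deploy code:

def render_schema(template: str, schema_name: str, embedding_dim: int) -> str:
    # Replace the placeholders with concrete, deploy-time values.
    return template.replace("DANSWER_CHUNK_NAME", schema_name).replace(
        "VARIABLE_DIM", str(embedding_dim)
    )

line = "field title_embedding type tensor<float>(x[VARIABLE_DIM]) {"
assert render_schema(line, "danswer_chunk", 768).endswith("(x[768]) {")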

Some files were not shown because too many files have changed in this diff.