mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 15:55:45 +00:00
Compare commits
39 Commits
fix_openap
...
fix_error_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9cab91ff92 | ||
|
|
692058092f | ||
|
|
e88325aad6 | ||
|
|
7490250e91 | ||
|
|
e5369fcef8 | ||
|
|
b0f00953bc | ||
|
|
f6a75c86c6 | ||
|
|
ed9989282f | ||
|
|
e80a0f2716 | ||
|
|
909403a648 | ||
|
|
cd84b65011 | ||
|
|
413f21cec0 | ||
|
|
eb369384a7 | ||
|
|
0a24dbc52c | ||
|
|
a7ba0da8cc | ||
|
|
aaced6d551 | ||
|
|
4c230f92ea | ||
|
|
07d75b04d1 | ||
|
|
a8d10750c1 | ||
|
|
85e3ed57f1 | ||
|
|
e10cc8ccdb | ||
|
|
7018bc974b | ||
|
|
9c9075d71d | ||
|
|
338e084062 | ||
|
|
2f64031f5c | ||
|
|
abb74f2eaa | ||
|
|
a3e3d83b7e | ||
|
|
4dc88ca037 | ||
|
|
11e7e1c4d6 | ||
|
|
f2d74ce540 | ||
|
|
25389c5120 | ||
|
|
ad0721ecd8 | ||
|
|
426a8842ae | ||
|
|
a98dcbc7de | ||
|
|
6f389dc100 | ||
|
|
d56177958f | ||
|
|
0e42ae9024 | ||
|
|
ce2b4de245 | ||
|
|
a515aa78d2 |
1
.github/CODEOWNERS
vendored
Normal file
1
.github/CODEOWNERS
vendored
Normal file
@@ -0,0 +1 @@
|
||||
* @onyx-dot-app/onyx-core-team
|
||||
94
.github/workflows/nightly-scan-licenses.yml
vendored
94
.github/workflows/nightly-scan-licenses.yml
vendored
@@ -53,24 +53,90 @@ jobs:
|
||||
exclude: '(?i)^(pylint|aio[-_]*).*'
|
||||
|
||||
- name: Print report
|
||||
if: ${{ always() }}
|
||||
if: always()
|
||||
run: echo "${{ steps.license_check_report.outputs.report }}"
|
||||
|
||||
- name: Install npm dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Trivy vulnerability scanner in repo mode
|
||||
uses: aquasecurity/trivy-action@0.28.0
|
||||
with:
|
||||
scan-type: fs
|
||||
scanners: license
|
||||
format: table
|
||||
# format: sarif
|
||||
# output: trivy-results.sarif
|
||||
severity: HIGH,CRITICAL
|
||||
|
||||
# - name: Upload Trivy scan results to GitHub Security tab
|
||||
# uses: github/codeql-action/upload-sarif@v3
|
||||
# be careful enabling the sarif and upload as it may spam the security tab
|
||||
# with a huge amount of items. Work out the issues before enabling upload.
|
||||
# - name: Run Trivy vulnerability scanner in repo mode
|
||||
# if: always()
|
||||
# uses: aquasecurity/trivy-action@0.29.0
|
||||
# with:
|
||||
# sarif_file: trivy-results.sarif
|
||||
# scan-type: fs
|
||||
# scan-ref: .
|
||||
# scanners: license
|
||||
# format: table
|
||||
# severity: HIGH,CRITICAL
|
||||
# # format: sarif
|
||||
# # output: trivy-results.sarif
|
||||
#
|
||||
# # - name: Upload Trivy scan results to GitHub Security tab
|
||||
# # uses: github/codeql-action/upload-sarif@v3
|
||||
# # with:
|
||||
# # sarif_file: trivy-results.sarif
|
||||
|
||||
scan-trivy:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=2cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
# Backend
|
||||
- name: Pull backend docker image
|
||||
run: docker pull onyxdotapp/onyx-backend:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on backend
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-backend:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0 # Set to 1 if we want a failed scan to fail the workflow
|
||||
|
||||
# Web server
|
||||
- name: Pull web server docker image
|
||||
run: docker pull onyxdotapp/onyx-web-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner on web server
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-web-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
|
||||
# Model server
|
||||
- name: Pull model server docker image
|
||||
run: docker pull onyxdotapp/onyx-model-server:latest
|
||||
|
||||
- name: Run Trivy vulnerability scanner
|
||||
uses: aquasecurity/trivy-action@0.29.0
|
||||
env:
|
||||
TRIVY_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-db:2'
|
||||
TRIVY_JAVA_DB_REPOSITORY: 'public.ecr.aws/aquasecurity/trivy-java-db:1'
|
||||
with:
|
||||
image-ref: onyxdotapp/onyx-model-server:latest
|
||||
scanners: license
|
||||
severity: HIGH,CRITICAL
|
||||
vuln-type: library
|
||||
exit-code: 0
|
||||
84
backend/alembic/versions/3bd4c84fe72f_improved_index.py
Normal file
84
backend/alembic/versions/3bd4c84fe72f_improved_index.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""improved index
|
||||
|
||||
Revision ID: 3bd4c84fe72f
|
||||
Revises: 8f43500ee275
|
||||
Create Date: 2025-02-26 13:07:56.217791
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3bd4c84fe72f"
|
||||
down_revision = "8f43500ee275"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
# NOTE:
|
||||
# This migration addresses issues with the previous migration (8f43500ee275) which caused
|
||||
# an outage by creating an index without using CONCURRENTLY. This migration:
|
||||
#
|
||||
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
|
||||
# 2. Uses CONCURRENTLY for all index creation to prevent table locking
|
||||
# 3. Explicitly manages transactions with COMMIT statements to allow CONCURRENTLY to work
|
||||
# (see: https://www.postgresql.org/docs/9.4/sql-createindex.html#SQL-CREATEINDEX-CONCURRENTLY)
|
||||
# (see: https://github.com/sqlalchemy/alembic/issues/277)
|
||||
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add stored tsvector columns and GIN indexes for chat full-text search.

    For ``chat_message.message`` and ``chat_session.description`` this adds a
    generated, stored ``tsvector`` column and then builds a GIN index over it
    with ``CREATE INDEX CONCURRENTLY``.

    An explicit ``COMMIT`` is issued before each concurrent index build so
    that the statement does not run inside the migration's transaction.  A
    consequence is that this migration is NOT atomic: if an index build
    fails, the column added before it has already been committed.  Every DDL
    statement therefore uses ``IF NOT EXISTS`` so the migration can be re-run
    after a partial failure instead of aborting on "column already exists".
    """
    # Stored, generated tsvector for chat_message.message.
    # IF NOT EXISTS keeps a re-run after a partial failure from erroring here.
    op.execute(
        """
        ALTER TABLE chat_message
        ADD COLUMN IF NOT EXISTS message_tsv tsvector
        GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
        """
    )

    # Commit the current transaction before creating concurrent indexes
    op.execute("COMMIT")

    # NOTE(review): a failed CONCURRENTLY build may leave an invalid index
    # behind, which IF NOT EXISTS would then skip on a re-run -- confirm how
    # operators are expected to clean that up.
    op.execute(
        """
        CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
        ON chat_message
        USING GIN (message_tsv)
        """
    )

    # Same treatment for chat_session.description; coalesce() maps a NULL
    # description to the empty string before it is converted.
    op.execute(
        """
        ALTER TABLE chat_session
        ADD COLUMN IF NOT EXISTS description_tsv tsvector
        GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
        """
    )

    # Commit again before creating the second concurrent index
    op.execute("COMMIT")

    op.execute(
        """
        CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
        ON chat_session
        USING GIN (description_tsv)
        """
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Remove the full-text search indexes and tsvector columns.

    Statements run in this order: both GIN indexes are dropped concurrently
    (each preceded by an explicit ``COMMIT``), then the two generated columns
    are dropped, and finally ``idx_chat_message_message_lower`` is dropped if
    it exists.
    """
    # Drop the indexes first (use CONCURRENTLY for dropping too); a COMMIT is
    # issued ahead of each concurrent drop.
    for index_name in ("idx_chat_message_tsv", "idx_chat_session_desc_tsv"):
        op.execute("COMMIT")
        op.execute(f"DROP INDEX CONCURRENTLY IF EXISTS {index_name};")

    # Then drop the columns
    generated_columns = (
        ("chat_message", "message_tsv"),
        ("chat_session", "description_tsv"),
    )
    for table_name, column_name in generated_columns:
        op.execute(f"ALTER TABLE {table_name} DROP COLUMN IF EXISTS {column_name};")

    # NOTE(review): this index is not created by this revision's upgrade();
    # presumably it is a leftover from an earlier revision -- confirm.
    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")
|
||||
@@ -18,12 +18,13 @@ depends_on = None
|
||||
def upgrade() -> None:
|
||||
# Create a basic index on the lowercase message column for direct text matching
|
||||
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX idx_chat_message_message_lower
|
||||
ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
"""
|
||||
)
|
||||
# op.execute(
|
||||
# """
|
||||
# CREATE INDEX idx_chat_message_message_lower
|
||||
# ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
# """
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
"""force lowercase all users
|
||||
|
||||
Revision ID: f11b408e39d3
|
||||
Revises: 3bd4c84fe72f
|
||||
Create Date: 2025-02-26 17:04:55.683500
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "f11b408e39d3"
|
||||
down_revision = "3bd4c84fe72f"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Lowercase every stored user email and enforce lowercase from now on."""
    # The import is kept function-local, as in the original migration.
    from alembic import op

    # 1) Convert all existing user emails to lowercase
    # NOTE(review): if "user".email carries a unique constraint, two rows that
    # differ only by case would collide here -- verify before running.
    lowercase_existing_emails = """
        UPDATE "user"
        SET email = LOWER(email)
        """
    op.execute(lowercase_existing_emails)

    # 2) Add a check constraint to ensure emails are always lowercase
    constraint_name = "ensure_lowercase_email"
    lowercase_condition = "email = LOWER(email)"
    op.create_check_constraint(constraint_name, "user", lowercase_condition)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the lowercase-email check constraint from the "user" table.

    Only the constraint is removed; emails that upgrade() lowercased are left
    as they are.
    """
    # The import is kept function-local, as in the original migration.
    from alembic import op

    constraint_name = "ensure_lowercase_email"
    op.drop_constraint(constraint_name, "user", type_="check")
|
||||
@@ -0,0 +1,42 @@
|
||||
"""lowercase multi-tenant user auth
|
||||
|
||||
Revision ID: 34e3630c7f32
|
||||
Revises: a4f6ee863c47
|
||||
Create Date: 2025-02-26 15:03:01.211894
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "34e3630c7f32"
|
||||
down_revision = "a4f6ee863c47"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Lowercase all emails in user_tenant_mapping and enforce lowercase."""
    # 1) Convert all existing rows to lowercase
    lowercase_existing_rows = """
        UPDATE user_tenant_mapping
        SET email = LOWER(email)
        """
    op.execute(lowercase_existing_rows)

    # 2) Add a check constraint so that emails cannot be written in uppercase.
    #    The constraint is created explicitly in the "public" schema.
    constraint_name = "ensure_lowercase_email"
    lowercase_condition = "email = LOWER(email)"
    op.create_check_constraint(
        constraint_name,
        "user_tenant_mapping",
        lowercase_condition,
        schema="public",
    )
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the lowercase-email check constraint from public.user_tenant_mapping.

    Only the constraint is removed; rows that upgrade() lowercased are left
    as they are.
    """
    constraint_name = "ensure_lowercase_email"
    op.drop_constraint(
        constraint_name,
        "user_tenant_mapping",
        schema="public",
        type_="check",
    )
|
||||
@@ -5,11 +5,9 @@ from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.background.task_utils import build_celery_task_wrapper
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.db.chat import delete_chat_sessions_older_than
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -18,10 +16,8 @@ logger = setup_logger()
|
||||
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def perform_ttl_management_task(
|
||||
retention_limit_days: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_sessions_older_than(retention_limit_days, db_session)
|
||||
|
||||
|
||||
@@ -35,24 +31,19 @@ def perform_ttl_management_task(
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
token = None
|
||||
if MULTI_TENANT and tenant_id is not None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(
|
||||
retention_limit_days=retention_limit_days, tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
@@ -60,9 +51,9 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
|
||||
"""This generates usage report under the /admin/generate-usage/report endpoint"""
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=None,
|
||||
|
||||
@@ -18,7 +18,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def monitor_usergroup_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
"""This function is likely to move in the worker refactor happening next."""
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -59,10 +59,14 @@ SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")
|
||||
|
||||
OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
|
||||
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
|
||||
OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
|
||||
OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
|
||||
)
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
|
||||
)
|
||||
OAUTH_JIRA_CLOUD_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_ID", "")
|
||||
OAUTH_JIRA_CLOUD_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLOUD_CLIENT_SECRET", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
|
||||
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
|
||||
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
|
||||
|
||||
@@ -424,7 +424,7 @@ def _validate_curator_status__no_commit(
|
||||
)
|
||||
|
||||
# if the user is a curator in any of their groups, set their role to CURATOR
|
||||
# otherwise, set their role to BASIC
|
||||
# otherwise, set their role to BASIC only if they were previously a CURATOR
|
||||
if curator_relationships:
|
||||
user.role = UserRole.CURATOR
|
||||
elif user.role == UserRole.CURATOR:
|
||||
@@ -631,7 +631,16 @@ def update_user_group(
|
||||
removed_users = db_session.scalars(
|
||||
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
|
||||
).unique()
|
||||
_validate_curator_status__no_commit(db_session, list(removed_users))
|
||||
|
||||
# Filter out admin and global curator users before validating curator status
|
||||
users_to_validate = [
|
||||
user
|
||||
for user in removed_users
|
||||
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
|
||||
]
|
||||
|
||||
if users_to_validate:
|
||||
_validate_curator_status__no_commit(db_session, users_to_validate)
|
||||
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
@@ -9,12 +9,16 @@ from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GR
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -342,7 +346,8 @@ def _fetch_all_page_restrictions(
|
||||
|
||||
|
||||
def confluence_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
@@ -354,7 +359,11 @@ def confluence_doc_sync(
|
||||
confluence_connector = ConfluenceConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
confluence_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
|
||||
provider = OnyxDBCredentialsProvider(
|
||||
get_current_tenant_id(), "confluence", cc_pair.credential_id
|
||||
)
|
||||
confluence_connector.set_credentials_provider(provider)
|
||||
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from ee.onyx.db.external_perm import ExternalUserGroup
|
||||
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
|
||||
from onyx.background.error_logging import emit_background_error
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
get_user_email_from_username__server,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import get_user_email_from_username__server
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -61,13 +63,27 @@ def _build_group_member_email_map(
|
||||
|
||||
|
||||
def confluence_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
confluence_client = build_confluence_client(
|
||||
credentials=cc_pair.credential.credential_json,
|
||||
is_cloud=cc_pair.connector.connector_specific_config.get("is_cloud", False),
|
||||
wiki_base=cc_pair.connector.connector_specific_config["wiki_base"],
|
||||
)
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id)
|
||||
is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False)
|
||||
wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"]
|
||||
url = wiki_base.rstrip("/")
|
||||
|
||||
probe_kwargs = {
|
||||
"max_backoff_retries": 6,
|
||||
"max_backoff_seconds": 10,
|
||||
}
|
||||
|
||||
final_kwargs = {
|
||||
"max_backoff_retries": 10,
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
confluence_client = OnyxConfluence(is_cloud, url, provider)
|
||||
confluence_client._probe_connection(**probe_kwargs)
|
||||
confluence_client._initialize_connection(**final_kwargs)
|
||||
|
||||
group_member_email_map = _build_group_member_email_map(
|
||||
confluence_client=confluence_client,
|
||||
|
||||
@@ -32,7 +32,8 @@ def _get_slim_doc_generator(
|
||||
|
||||
|
||||
def gmail_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
||||
@@ -145,7 +145,8 @@ def _get_permissions_from_slim_doc(
|
||||
|
||||
|
||||
def gdrive_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
||||
@@ -119,6 +119,7 @@ def _build_onyx_groups(
|
||||
|
||||
|
||||
def gdrive_group_sync(
|
||||
tenant_id: str,
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
) -> list[ExternalUserGroup]:
|
||||
# Initialize connector and build credential/service objects
|
||||
|
||||
@@ -123,7 +123,8 @@ def _fetch_channel_permissions(
|
||||
|
||||
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> list[DocExternalAccess]:
|
||||
"""
|
||||
Adds the external permissions to the documents in postgres
|
||||
|
||||
@@ -28,6 +28,7 @@ DocSyncFuncType = Callable[
|
||||
|
||||
GroupSyncFuncType = Callable[
|
||||
[
|
||||
str,
|
||||
ConnectorCredentialPair,
|
||||
],
|
||||
list[ExternalUserGroup],
|
||||
|
||||
@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
)
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
|
||||
from ee.onyx.server.oauth import router as oauth_router
|
||||
from ee.onyx.server.oauth.api import router as oauth_router
|
||||
from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
@@ -152,4 +152,8 @@ def get_application() -> FastAPI:
|
||||
# environment variable. Used to automate deployment for multiple environments.
|
||||
seed_db()
|
||||
|
||||
# for debugging discovered routes
|
||||
# for route in application.router.routes:
|
||||
# print(f"Path: {route.path}, Methods: {route.methods}")
|
||||
|
||||
return application
|
||||
|
||||
@@ -1,629 +0,0 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/oauth")
|
||||
|
||||
|
||||
class SlackOAuth:
    """Constants and helpers for the Slack OAuth v2 authorization flow.

    References:
    - https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
    - Example: https://api.slack.com/authentication/oauth-v2#exchanging
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        # Where to send the user if OAuth flow succeeds
        redirect_on_success: str | None

    CLIENT_ID = OAUTH_SLACK_CLIENT_ID
    CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET

    TOKEN_URL = "https://slack.com/api/oauth.v2.access"

    # SCOPE is per https://docs.onyx.app/connectors/slack
    BOT_SCOPE = (
        "channels:history,"
        "channels:read,"
        "groups:history,"
        "groups:read,"
        "channels:join,"
        "im:history,"
        "users:read,"
        "users:read.email,"
        "usergroups:read"
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Return the Slack authorization URL that redirects to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the authorize URL from the client id, redirect, scope and state."""
        # NOTE(review): values are interpolated as-is, without URL-encoding.
        query_string = "&".join(
            (
                f"client_id={cls.CLIENT_ID}",
                f"redirect_uri={redirect_uri}",
                f"scope={cls.BOT_SCOPE}",
                f"state={state}",
            )
        )
        return f"https://slack.com/oauth/v2/authorize?{query_string}"

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis. to be looked up on auth response.
        Returns a json string.
        """
        return SlackOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Rebuild an OAuthSession from the JSON produced by session_dump_json."""
        return SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
|
||||
|
||||
class ConfluenceCloudOAuth:
    """Constants and helpers for the Confluence Cloud OAuth flow (work in progress).

    Reference: https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        # Where to send the user if OAuth flow succeeds
        redirect_on_success: str | None

    CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
    CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
    TOKEN_URL = "https://auth.atlassian.com/oauth/token"

    # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
    # The scopes are already joined with a URL-encoded space (%20).
    CONFLUENCE_OAUTH_SCOPE = (
        "read:confluence-props%20"
        "read:confluence-content.all%20"
        "read:confluence-content.summary%20"
        "read:confluence-content.permission%20"
        "read:confluence-user%20"
        "read:confluence-groups%20"
        "readonly:content.attachment:confluence"
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    # eventually for Confluence Data Center
    # oauth_url = (
    #     f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
    #     f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
    #     f"&redirect_uri={redirectme_uri}"
    # )

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Return the Atlassian authorization URL that redirects to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the authorize URL from the client id, redirect, scope and state."""
        # NOTE(review): values are interpolated as-is, without URL-encoding.
        url = (
            "https://auth.atlassian.com/authorize"
            f"?audience=api.atlassian.com"
            f"&client_id={cls.CLIENT_ID}"
            f"&redirect_uri={redirect_uri}"
            f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
            f"&state={state}"
            "&response_type=code"
            "&prompt=consent"
        )
        return url

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis. to be looked up on auth response.
        Returns a json string.
        """
        session = ConfluenceCloudOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        )
        return session.model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Rebuild an OAuthSession from the JSON produced by session_dump_json.

        Uses this class's own OAuthSession model, so the value round-trips
        through the same model that session_dump_json serialized and the
        class does not depend on SlackOAuth. The model declares the same two
        fields (email, redirect_on_success) as the Slack one.
        """
        session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)
        return session
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
    """Constants and helpers for the Google Drive OAuth 2.0 web-server flow.

    References:
    - https://developers.google.com/identity/protocols/oauth2
    - https://developers.google.com/identity/protocols/oauth2/web-server
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        # Where to send the user if OAuth flow succeeds
        redirect_on_success: str | None

    CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
    CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET

    TOKEN_URL = "https://oauth2.googleapis.com/token"

    # SCOPE is per https://docs.onyx.app/connectors/google-drive
    # TODO: Merge with or use google_utils.GOOGLE_SCOPES
    # The scopes are already joined with a URL-encoded space (%20).
    SCOPE = (
        "https://www.googleapis.com/auth/drive.readonly%20"
        "https://www.googleapis.com/auth/drive.metadata.readonly%20"
        "https://www.googleapis.com/auth/admin.directory.user.readonly%20"
        "https://www.googleapis.com/auth/admin.directory.group.readonly"
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
    DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Return the Google authorization URL that redirects to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the authorize URL from the client id, redirect, scope and state."""
        # without prompt=consent, a refresh token is only issued the first time the user approves
        # NOTE(review): values are interpolated as-is, without URL-encoding.
        query_string = "&".join(
            (
                f"client_id={cls.CLIENT_ID}",
                f"redirect_uri={redirect_uri}",
                "response_type=code",
                f"scope={cls.SCOPE}",
                "access_type=offline",
                f"state={state}",
                "prompt=consent",
            )
        )
        return f"https://accounts.google.com/o/oauth2/v2/auth?{query_string}"

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Temporary state to store in redis. to be looked up on auth response.
        Returns a json string.
        """
        return GoogleDriveOAuth.OAuthSession(
            email=email, redirect_on_success=redirect_on_success
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Rebuild an OAuthSession from the JSON produced by session_dump_json."""
        return GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
    connector: DocumentSource,
    redirect_on_success: str | None,
    user: User = Depends(current_user),
) -> JSONResponse:
    """Used by the frontend to generate the url for the user's browser during auth request.

    A random uuid is generated per request. Its url-safe base64 form is sent
    along as the OAuth ``state`` parameter, and its string form keys the
    session json stored in redis for 10 minutes.

    Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
    """
    tenant_id = get_current_tenant_id()

    # create random oauth state param for security and to retrieve user data later
    oauth_uuid = uuid.uuid4()

    # urlsafe b64 encode the uuid for the oauth url (padding stripped)
    encoded_uuid = base64.urlsafe_b64encode(oauth_uuid.bytes)
    oauth_state = encoded_uuid.rstrip(b"=").decode("utf-8")

    # Connectors with an enabled OAuth implementation.
    # Not enabled yet:
    #   DocumentSource.CONFLUENCE -> ConfluenceCloudOAuth.generate_oauth_url / session_dump_json
    #   DocumentSource.JIRA -> JiraCloudOAuth.generate_dev_oauth_url
    providers: dict[DocumentSource, type[SlackOAuth] | type[GoogleDriveOAuth]] = {
        DocumentSource.SLACK: SlackOAuth,
        DocumentSource.GOOGLE_DRIVE: GoogleDriveOAuth,
    }
    provider = providers.get(connector)

    oauth_url: str | None = None
    session: str
    if provider is not None:
        oauth_url = provider.generate_oauth_url(oauth_state)
        session = provider.session_dump_json(
            email=user.email, redirect_on_success=redirect_on_success
        )

    if not oauth_url:
        raise HTTPException(
            status_code=404,
            detail=f"The document source type {connector} does not have OAuth implemented",
        )

    # store important session state to retrieve when the user is redirected back
    # 10 min is the max we want an oauth flow to be valid
    redis_client = get_redis_client(tenant_id=tenant_id)
    redis_client.set(f"da_oauth:{str(oauth_uuid)}", session, ex=600)

    return JSONResponse(content={"url": oauth_url})
|
||||
|
||||
|
||||
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> JSONResponse:
    """Finish the Slack OAuth flow for the ``code``/``state`` pair sent back by the browser.

    ``state`` is the unpadded url-safe base64 form of a uuid. The uuid keys
    the session json stored in redis. ``code`` is exchanged for an access
    token at SlackOAuth.TOKEN_URL and the token is saved through
    create_credential. The redis key is deleted whether or not the exchange
    succeeds.

    Raises:
        HTTPException: 500 when the Slack client id/secret is not configured,
            400 when ``state`` is malformed or no session is stored for it.

    Returns:
        A success payload with team/user ids and ``redirect_on_success``, or a
        500 JSON payload with ``success: False`` if anything fails after the
        session lookup.
    """
    if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Slack client ID or client secret is not configured.",
        )

    r = get_redis_client()

    # recover the state
    # ``state`` comes from the request, so a value that is not valid base64 or
    # does not decode to 16 bytes is reported as a client error (400) rather
    # than escaping as an unhandled exception. binascii.Error is a ValueError.
    try:
        padded_state = state + "=" * (
            -len(state) % 4
        )  # Add padding back (Base64 decoding requires padding)
        uuid_bytes = base64.urlsafe_b64decode(
            padded_state
        )  # Decode the Base64 string back to bytes

        # Convert bytes back to a UUID
        oauth_uuid = uuid.UUID(bytes=uuid_bytes)
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail="Slack OAuth failed - malformed OAuth state parameter.",
        ) from e

    oauth_uuid_str = str(oauth_uuid)

    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    try:
        session = SlackOAuth.parse_session(session_json)

        # Exchange the authorization code for an access token
        response = requests.post(
            SlackOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": SlackOAuth.CLIENT_ID,
                "client_secret": SlackOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": SlackOAuth.REDIRECT_URI,
            },
        )

        response_data = response.json()

        if not response_data.get("ok"):
            # NOTE(review): raised inside the try block, so the generic handler
            # below catches it and the client receives the 500 JSON payload,
            # not a 400.
            raise HTTPException(
                status_code=400,
                detail=f"Slack OAuth failed: {response_data.get('error')}",
            )

        # Extract token and team information
        access_token: str = response_data.get("access_token")
        team_id: str = response_data.get("team", {}).get("id")
        authed_user_id: str = response_data.get("authed_user", {}).get("id")

        credential_info = CredentialBase(
            credential_json={"slack_bot_token": access_token},
            admin_public=True,
            source=DocumentSource.SLACK,
            name="Slack OAuth",
        )

        create_credential(credential_info, user, db_session)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Slack OAuth: {str(e)}",
            },
        )
    finally:
        # the stored session is single use: drop it on success and on failure
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Slack OAuth completed successfully.",
            "team_id": team_id,
            "authed_user_id": authed_user_id,
            "redirect_on_success": session.redirect_on_success,
        }
    )
|
||||
|
||||
|
||||
# Work in progress
|
||||
# @router.post("/connector/confluence/callback")
|
||||
# def handle_confluence_oauth_callback(
|
||||
# code: str,
|
||||
# state: str,
|
||||
# user: User = Depends(current_user),
|
||||
# db_session: Session = Depends(get_session),
|
||||
# tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
# ) -> JSONResponse:
|
||||
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
|
||||
# raise HTTPException(
|
||||
# status_code=500,
|
||||
# detail="Confluence client ID or client secret is not configured."
|
||||
# )
|
||||
|
||||
# r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# # recover the state
|
||||
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
|
||||
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes
|
||||
|
||||
# # Convert bytes back to a UUID
|
||||
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
|
||||
# oauth_uuid_str = str(oauth_uuid)
|
||||
|
||||
# r_key = f"da_oauth:{oauth_uuid_str}"
|
||||
|
||||
# result = r.get(r_key)
|
||||
# if not result:
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
|
||||
# )
|
||||
|
||||
# try:
|
||||
# session = ConfluenceCloudOAuth.parse_session(result)
|
||||
|
||||
# # Exchange the authorization code for an access token
|
||||
# response = requests.post(
|
||||
# ConfluenceCloudOAuth.TOKEN_URL,
|
||||
# headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
# data={
|
||||
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
|
||||
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
|
||||
# "code": code,
|
||||
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
|
||||
# },
|
||||
# )
|
||||
|
||||
# response_data = response.json()
|
||||
|
||||
# if not response_data.get("ok"):
|
||||
# raise HTTPException(
|
||||
# status_code=400,
|
||||
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
|
||||
# )
|
||||
|
||||
# # Extract token and team information
|
||||
# access_token: str = response_data.get("access_token")
|
||||
# team_id: str = response_data.get("team", {}).get("id")
|
||||
# authed_user_id: str = response_data.get("authed_user", {}).get("id")
|
||||
|
||||
# credential_info = CredentialBase(
|
||||
# credential_json={"slack_bot_token": access_token},
|
||||
# admin_public=True,
|
||||
# source=DocumentSource.CONFLUENCE,
|
||||
# name="Confluence OAuth",
|
||||
# )
|
||||
|
||||
# logger.info(f"Slack access token: {access_token}")
|
||||
|
||||
# credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
# logger.info(f"new_credential_id={credential.id}")
|
||||
# except Exception as e:
|
||||
# return JSONResponse(
|
||||
# status_code=500,
|
||||
# content={
|
||||
# "success": False,
|
||||
# "message": f"An error occurred during Slack OAuth: {str(e)}",
|
||||
# },
|
||||
# )
|
||||
# finally:
|
||||
# r.delete(r_key)
|
||||
|
||||
# # return the result
|
||||
# return JSONResponse(
|
||||
# content={
|
||||
# "success": True,
|
||||
# "message": "Slack OAuth completed successfully.",
|
||||
# "team_id": team_id,
|
||||
# "authed_user_id": authed_user_id,
|
||||
# "redirect_on_success": session.redirect_on_success,
|
||||
# }
|
||||
# )
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> JSONResponse:
    """Finish the Google Drive OAuth flow for the ``code``/``state`` pair sent back by the browser.

    ``state`` is the unpadded url-safe base64 form of a uuid. The uuid keys
    the session json stored in redis. ``code`` is exchanged at
    GoogleDriveOAuth.TOKEN_URL, the returned refresh token is turned into
    credentials via get_google_oauth_creds, and the sanitized result is saved
    through create_credential. The redis key is deleted whether or not the
    exchange succeeds.

    Raises:
        HTTPException: 500 when the Google Drive client id/secret is not
            configured, 400 when ``state`` is malformed or no session is
            stored for it.

    Returns:
        A success payload with ``redirect_on_success``, or a 500 JSON payload
        with ``success: False`` if anything fails after the session lookup.
    """
    if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Google Drive client ID or client secret is not configured.",
        )

    r = get_redis_client()

    # recover the state
    # ``state`` comes from the request, so a value that is not valid base64 or
    # does not decode to 16 bytes is reported as a client error (400) rather
    # than escaping as an unhandled exception. binascii.Error is a ValueError.
    try:
        padded_state = state + "=" * (
            -len(state) % 4
        )  # Add padding back (Base64 decoding requires padding)
        uuid_bytes = base64.urlsafe_b64decode(
            padded_state
        )  # Decode the Base64 string back to bytes

        # Convert bytes back to a UUID
        oauth_uuid = uuid.UUID(bytes=uuid_bytes)
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail="Google Drive OAuth failed - malformed OAuth state parameter.",
        ) from e

    oauth_uuid_str = str(oauth_uuid)

    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    session: GoogleDriveOAuth.OAuthSession
    try:
        session = GoogleDriveOAuth.parse_session(session_json)

        # Exchange the authorization code for an access token
        response = requests.post(
            GoogleDriveOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": GoogleDriveOAuth.CLIENT_ID,
                "client_secret": GoogleDriveOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": GoogleDriveOAuth.REDIRECT_URI,
                "grant_type": "authorization_code",
            },
        )

        response.raise_for_status()

        authorization_response: dict[str, Any] = response.json()

        # the connector wants us to store the json in its authorized_user_info format
        # returned from OAuthCredentials.get_authorized_user_info().
        # So refresh immediately via get_google_oauth_creds with the params filled in
        # from fields in authorization_response to get the json we need
        authorized_user_info = {}
        authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
        authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
        # a response without "refresh_token" raises KeyError here, which the
        # generic handler below reports as a failure
        authorized_user_info["refresh_token"] = authorization_response["refresh_token"]

        token_json_str = json.dumps(authorized_user_info)
        oauth_creds = get_google_oauth_creds(
            token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
        )
        if not oauth_creds:
            raise RuntimeError("get_google_oauth_creds returned None.")

        # save off the credentials
        oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)

        credential_dict: dict[str, str] = {}
        credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
        credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
        credential_dict[
            DB_CREDENTIALS_AUTHENTICATION_METHOD
        ] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value

        credential_info = CredentialBase(
            credential_json=credential_dict,
            admin_public=True,
            source=DocumentSource.GOOGLE_DRIVE,
            name="OAuth (interactive)",
        )

        create_credential(credential_info, user, db_session)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Google Drive OAuth: {str(e)}",
            },
        )
    finally:
        # the stored session is single use: drop it on success and on failure
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Google Drive OAuth completed successfully.",
            "redirect_on_success": session.redirect_on_success,
        }
    )
|
||||
91
backend/ee/onyx/server/oauth/api.py
Normal file
91
backend/ee/onyx/server/oauth/api.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import base64
|
||||
import uuid
|
||||
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth
|
||||
from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth
|
||||
from ee.onyx.server.oauth.slack import SlackOAuth
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@router.post("/prepare-authorization-request")
def prepare_authorization_request(
    connector: DocumentSource,
    redirect_on_success: str | None,
    user: User = Depends(current_admin_user),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Used by the frontend to generate the url for the user's browser during auth request.

    A random uuid is generated per request. Its url-safe base64 form is sent
    along as the OAuth ``state`` parameter, and its string form keys the
    session json stored in redis for 10 minutes. In DEV_MODE the provider's
    dev url is used instead of the regular one.

    Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
    """
    # create random oauth state param for security and to retrieve user data later
    oauth_uuid = uuid.uuid4()

    # urlsafe b64 encode the uuid for the oauth url (padding stripped)
    encoded_uuid = base64.urlsafe_b64encode(oauth_uuid.bytes)
    oauth_state = encoded_uuid.rstrip(b"=").decode("utf-8")

    # connectors that have an OAuth implementation
    providers: dict[
        DocumentSource,
        type[SlackOAuth] | type[ConfluenceCloudOAuth] | type[GoogleDriveOAuth],
    ] = {
        DocumentSource.SLACK: SlackOAuth,
        DocumentSource.CONFLUENCE: ConfluenceCloudOAuth,
        DocumentSource.GOOGLE_DRIVE: GoogleDriveOAuth,
    }
    provider = providers.get(connector)

    oauth_url: str | None = None
    session: str | None = None
    if provider is not None:
        if DEV_MODE:
            oauth_url = provider.generate_dev_oauth_url(oauth_state)
        else:
            oauth_url = provider.generate_oauth_url(oauth_state)

        session = provider.session_dump_json(
            email=user.email, redirect_on_success=redirect_on_success
        )

    if not oauth_url:
        raise HTTPException(
            status_code=404,
            detail=f"The document source type {connector} does not have OAuth implemented",
        )

    if not session:
        raise HTTPException(
            status_code=500,
            detail=f"The document source type {connector} failed to generate an OAuth session.",
        )

    # store important session state to retrieve when the user is redirected back
    # 10 min is the max we want an oauth flow to be valid
    redis_client = get_redis_client(tenant_id=tenant_id)
    redis_client.set(f"da_oauth:{str(oauth_uuid)}", session, ex=600)

    return JSONResponse(content={"url": oauth_url})
|
||||
3
backend/ee/onyx/server/oauth/api_router.py
Normal file
3
backend/ee/onyx/server/oauth/api_router.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from fastapi import APIRouter
|
||||
|
||||
# Shared router for the OAuth endpoints. The api, confluence_cloud and
# google_drive modules import it and register their routes on it, so every
# route path is prefixed with "/oauth".
router: APIRouter = APIRouter(prefix="/oauth")
|
||||
361
backend/ee/onyx/server/oauth/confluence_cloud.py
Normal file
361
backend/ee/onyx/server/oauth/confluence_cloud.py
Normal file
@@ -0,0 +1,361 @@
|
||||
import base64
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.credentials import update_credential_json
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ConfluenceCloudOAuth:
    """Constants, models and helpers for the Confluence Cloud OAuth flow."""

    # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    class TokenResponse(BaseModel):
        access_token: str
        expires_in: int
        token_type: str
        refresh_token: str
        scope: str

    class AccessibleResources(BaseModel):
        id: str
        name: str
        url: str
        scopes: list[str]
        avatarUrl: str

    CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
    CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
    TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL

    ACCESSIBLE_RESOURCE_URL = "https://api.atlassian.com/oauth/token/accessible-resources"

    # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
    # The scopes are joined with "%20" (an already url-encoded space) because
    # the value is interpolated into the authorization url verbatim.
    CONFLUENCE_OAUTH_SCOPE = "%20".join(
        [
            # classic scope
            "read:confluence-space.summary",
            "read:confluence-props",
            "read:confluence-content.all",
            "read:confluence-content.summary",
            "read:confluence-content.permission",
            "read:confluence-user",
            "read:confluence-groups",
            "readonly:content.attachment:confluence",
            "search:confluence",
            # granular scope
            "read:attachment:confluence",  # possibly unneeded unless calling v2 attachments api
            "offline_access",
        ]
    )

    REDIRECT_URI = WEB_DOMAIN + "/admin/connectors/confluence/oauth/callback"
    DEV_REDIRECT_URI = "https://redirectmeto.com/" + REDIRECT_URI

    # eventually for Confluence Data Center
    # oauth_url = (
    #     f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
    #     f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
    #     f"&redirect_uri={redirectme_uri}"
    # )

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Build the authorization url that redirects back to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """dev mode workaround for localhost testing
        - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the authorization url for the given redirect uri and state.

        Values are interpolated as-is; no url-encoding is applied here.
        """
        # https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code
        query_params = (
            ("audience", "api.atlassian.com"),
            ("client_id", cls.CLIENT_ID),
            ("scope", cls.CONFLUENCE_OAUTH_SCOPE),
            ("redirect_uri", redirect_uri),
            ("state", state),
            ("response_type", "code"),
            ("prompt", "consent"),
        )
        query = "&".join(f"{key}={value}" for key, value in query_params)
        return "https://auth.atlassian.com/authorize" + "?" + query

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Serialize the temporary state to store in redis.

        The state is looked up again on the auth response.
        Returns a json string.
        """
        return ConfluenceCloudOAuth.OAuthSession(
            email=email,
            redirect_on_success=redirect_on_success,
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Parse a json string produced by session_dump_json into an OAuthSession."""
        return ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json)

    @classmethod
    def generate_finalize_url(cls, credential_id: int) -> str:
        """Return the admin page url used to finalize the given credential."""
        finalize_page = WEB_DOMAIN + "/admin/connectors/confluence/oauth/finalize"
        return f"{finalize_page}?credential={credential_id}"
|
||||
|
||||
|
||||
@router.post("/connector/confluence/callback")
def confluence_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Handles the backend logic for the frontend page that the user is redirected to
    after visiting the oauth authorization url.

    ``state`` is the unpadded url-safe base64 form of a uuid. The uuid keys
    the session json stored in redis. ``code`` is exchanged at
    ConfluenceCloudOAuth.TOKEN_URL and the resulting tokens are saved through
    create_credential. The redis key is deleted whether or not the exchange
    succeeds.

    Raises:
        HTTPException: 500 when the Confluence Cloud client id/secret is not
            configured, 400 when ``state`` is malformed or no session is
            stored for it.

    Returns:
        A success payload with ``finalize_url`` and ``redirect_on_success``,
        or a 500 JSON payload with ``success: False`` if anything fails after
        the session lookup.
    """

    if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Confluence Cloud client ID or client secret is not configured.",
        )

    r = get_redis_client(tenant_id=tenant_id)

    # recover the state
    # ``state`` comes from the request, so a value that is not valid base64 or
    # does not decode to 16 bytes is reported as a client error (400) rather
    # than escaping as an unhandled exception. binascii.Error is a ValueError.
    try:
        padded_state = state + "=" * (
            -len(state) % 4
        )  # Add padding back (Base64 decoding requires padding)
        uuid_bytes = base64.urlsafe_b64decode(
            padded_state
        )  # Decode the Base64 string back to bytes

        # Convert bytes back to a UUID
        oauth_uuid = uuid.UUID(bytes=uuid_bytes)
    except ValueError as e:
        raise HTTPException(
            status_code=400,
            detail="Confluence Cloud OAuth failed - malformed OAuth state parameter.",
        ) from e

    oauth_uuid_str = str(oauth_uuid)

    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Confluence Cloud OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    try:
        session = ConfluenceCloudOAuth.parse_session(session_json)

        # the redirect uri must be the one used when the authorization url was built
        if not DEV_MODE:
            redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI
        else:
            redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI

        # Exchange the authorization code for an access token
        response = requests.post(
            ConfluenceCloudOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": ConfluenceCloudOAuth.CLIENT_ID,
                "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": redirect_uri,
                "grant_type": "authorization_code",
            },
        )

        try:
            token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json(
                response.text
            )
        except Exception as e:
            # chain the cause so the parse failure stays visible in tracebacks;
            # the message returned to the client is unchanged
            raise RuntimeError(
                "Confluence Cloud OAuth failed during code/token exchange."
            ) from e

        now = datetime.now(timezone.utc)
        expires_at = now + timedelta(seconds=token_response.expires_in)

        credential_info = CredentialBase(
            credential_json={
                "confluence_access_token": token_response.access_token,
                "confluence_refresh_token": token_response.refresh_token,
                "created_at": now.isoformat(),
                "expires_at": expires_at.isoformat(),
                "expires_in": token_response.expires_in,
                "scope": token_response.scope,
            },
            admin_public=True,
            source=DocumentSource.CONFLUENCE,
            name="Confluence Cloud OAuth",
        )

        credential = create_credential(credential_info, user, db_session)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
            },
        )
    finally:
        # the stored session is single use: drop it on success and on failure
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Confluence Cloud OAuth completed successfully.",
            "finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id),
            "redirect_on_success": session.redirect_on_success,
        }
    )
|
||||
|
||||
|
||||
@router.get("/connector/confluence/accessible-resources")
def confluence_oauth_accessible_resources(
    credential_id: int,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Atlassian's API is weird and does not supply us with enough info to be in a
    usable state after authorizing. All API's require a cloud id. We have to list
    the accessible resources/sites and let the user choose which site to use.

    Raises:
        HTTPException: 400 when the credential is not found or holds no
            "confluence_access_token".

    Returns:
        A success payload with the ``accessible_resources`` list, or a 500
        JSON payload with ``success: False`` if the request or the parsing of
        its result fails.
    """

    credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
    if not credential:
        raise HTTPException(400, f"Credential {credential_id} not found.")

    credential_dict = credential.credential_json
    try:
        access_token = credential_dict["confluence_access_token"]
    except KeyError as e:
        # credential_id is caller-supplied and may point at a credential that
        # was not created by the Confluence OAuth callback
        raise HTTPException(
            400,
            f"Credential {credential_id} does not contain a Confluence Cloud access token.",
        ) from e

    try:
        # List the resources/sites the access token is allowed to use
        response = requests.get(
            ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL,
            headers={
                "Authorization": f"Bearer {access_token}",
                "Accept": "application/json",
            },
        )

        response.raise_for_status()
        accessible_resources_data = response.json()

        # Validate the list of AccessibleResources
        try:
            accessible_resources = [
                ConfluenceCloudOAuth.AccessibleResources(**resource)
                for resource in accessible_resources_data
            ]
        except ValidationError as e:
            raise RuntimeError(f"Failed to parse accessible resources: {e}") from e
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}",
            },
        )

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Confluence Cloud get accessible resources completed successfully.",
            "accessible_resources": [
                resource.model_dump() for resource in accessible_resources
            ],
        }
    )
|
||||
|
||||
|
||||
@router.post("/connector/confluence/finalize")
def confluence_oauth_finalize(
    credential_id: int,
    cloud_id: str,
    cloud_name: str,
    cloud_url: str,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Saves the info for the selected cloud site to the credential.

    This is the final step in the confluence oauth flow where after the traditional
    OAuth process, the user has to select a site to associate with the credentials.
    After this, the credential is usable.

    The existing credential json is copied and extended with "cloud_id",
    "cloud_name" and "wiki_base" (taken from ``cloud_url``) before being saved.
    """
    credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
    if not credential:
        raise HTTPException(
            status_code=400,
            detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
        )

    # copy of the stored json with the selected site's details merged in
    new_credential_json: dict[str, Any] = {
        **credential.credential_json,
        "cloud_id": cloud_id,
        "cloud_name": cloud_name,
        "wiki_base": cloud_url,
    }

    try:
        update_credential_json(credential_id, new_credential_json, user, db_session)
    except Exception as e:
        failure_payload = {
            "success": False,
            "message": f"An error occurred during Confluence Cloud OAuth: {str(e)}",
        }
        return JSONResponse(status_code=500, content=failure_payload)

    # return the result
    success_payload = {
        "success": True,
        "message": "Confluence Cloud OAuth finalized successfully.",
        "redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence",
    }
    return JSONResponse(content=success_payload)
|
||||
229
backend/ee/onyx/server/oauth/google_drive.py
Normal file
229
backend/ee/onyx/server/oauth/google_drive.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.google_utils.google_auth import get_google_oauth_creds
|
||||
from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_DICT_TOKEN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
|
||||
)
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
GoogleOAuthAuthenticationMethod,
|
||||
)
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
|
||||
|
||||
class GoogleDriveOAuth:
    """Configuration and helpers for the Google Drive connector OAuth flow.

    References:
    - https://developers.google.com/identity/protocols/oauth2
    - https://developers.google.com/identity/protocols/oauth2/web-server
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID
    CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET

    TOKEN_URL = "https://oauth2.googleapis.com/token"

    # SCOPE is per https://docs.danswer.dev/connectors/google-drive
    # TODO: Merge with or use google_utils.GOOGLE_SCOPES
    # The scopes are joined with a pre-encoded space ("%20") because the value
    # is interpolated directly into the authorization URL's query string.
    SCOPE = "%20".join(
        (
            "https://www.googleapis.com/auth/drive.readonly",
            "https://www.googleapis.com/auth/drive.metadata.readonly",
            "https://www.googleapis.com/auth/admin.directory.user.readonly",
            "https://www.googleapis.com/auth/admin.directory.group.readonly",
        )
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback"
    DEV_REDIRECT_URI = "https://redirectmeto.com/" + REDIRECT_URI

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Return the Google authorization URL that redirects to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """Return the authorization URL that redirects to DEV_REDIRECT_URI.

        Dev mode workaround for localhost testing, see
        https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the authorization URL for the given redirect URI and state.

        The values are interpolated as-is (no additional URL encoding).
        """
        # prompt=consent is required: without it, a refresh token is only
        # issued the first time the user approves
        query_parts = [
            f"client_id={cls.CLIENT_ID}",
            f"redirect_uri={redirect_uri}",
            "response_type=code",
            f"scope={cls.SCOPE}",
            "access_type=offline",
            f"state={state}",
            "prompt=consent",
        ]
        return "https://accounts.google.com/o/oauth2/v2/auth?" + "&".join(query_parts)

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Serialize the temporary OAuth state kept in redis until the callback.

        Returns a json string.
        """
        return GoogleDriveOAuth.OAuthSession(
            email=email,
            redirect_on_success=redirect_on_success,
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Rebuild an OAuthSession from the json produced by session_dump_json."""
        return GoogleDriveOAuth.OAuthSession.model_validate_json(session_json)
|
||||
|
||||
|
||||
@router.post("/connector/google-drive/callback")
def handle_google_drive_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Complete the Google Drive OAuth flow and store the resulting credential.

    Looks up the OAuth session stored in redis under the UUID encoded in
    ``state``, exchanges ``code`` for tokens at the Google token endpoint,
    and creates a Google Drive credential from the returned refresh token.

    Raises:
        HTTPException: 500 if the client ID/secret is not configured; 400 if
            ``state`` is malformed or no matching session exists in redis.

    Returns:
        A JSON response with ``success``/``message`` (plus ``finalize_url``
        and ``redirect_on_success`` on success). Failures during the token
        exchange or credential creation are reported as a 500 JSON response.
    """
    if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Google Drive client ID or client secret is not configured.",
        )

    r = get_redis_client(tenant_id=tenant_id)

    # recover the state
    try:
        # Add padding back (Base64 decoding requires padding)
        padded_state = state + "=" * (-len(state) % 4)
        # Decode the Base64 string back to bytes
        uuid_bytes = base64.urlsafe_b64decode(padded_state)
        # Convert bytes back to a UUID
        oauth_uuid = uuid.UUID(bytes=uuid_bytes)
    except ValueError as err:
        # binascii.Error (bad base64) is a ValueError subclass, as is the
        # error uuid.UUID raises for input that is not exactly 16 bytes.
        # A malformed state is a client error, not a server failure.
        raise HTTPException(
            status_code=400,
            detail="Google Drive OAuth failed - malformed OAuth state.",
        ) from err

    oauth_uuid_str = str(oauth_uuid)

    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    try:
        session = GoogleDriveOAuth.parse_session(session_json)

        if not DEV_MODE:
            redirect_uri = GoogleDriveOAuth.REDIRECT_URI
        else:
            redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI

        # Exchange the authorization code for an access token
        response = requests.post(
            GoogleDriveOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": GoogleDriveOAuth.CLIENT_ID,
                "client_secret": GoogleDriveOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": redirect_uri,
                "grant_type": "authorization_code",
            },
            # requests has no default timeout; bound the call so a stalled
            # token endpoint cannot hang this request handler forever
            timeout=30,
        )

        response.raise_for_status()

        authorization_response: dict[str, Any] = response.json()

        # the connector wants us to store the json in its authorized_user_info format
        # returned from OAuthCredentials.get_authorized_user_info().
        # So refresh immediately via get_google_oauth_creds with the params filled in
        # from fields in authorization_response to get the json we need
        authorized_user_info = {}
        authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID
        authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET
        authorized_user_info["refresh_token"] = authorization_response["refresh_token"]

        token_json_str = json.dumps(authorized_user_info)
        oauth_creds = get_google_oauth_creds(
            token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE
        )
        if not oauth_creds:
            raise RuntimeError("get_google_oauth_creds returned None.")

        # save off the credentials
        oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds)

        credential_dict: dict[str, str] = {}
        credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str
        credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email
        credential_dict[
            DB_CREDENTIALS_AUTHENTICATION_METHOD
        ] = GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value

        credential_info = CredentialBase(
            credential_json=credential_dict,
            admin_public=True,
            source=DocumentSource.GOOGLE_DRIVE,
            name="OAuth (interactive)",
        )

        create_credential(credential_info, user, db_session)
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Google Drive OAuth: {str(e)}",
            },
        )
    finally:
        # the state is single-use: drop it whether or not the exchange succeeded
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Google Drive OAuth completed successfully.",
            "finalize_url": None,
            "redirect_on_success": session.redirect_on_success,
        }
    )
|
||||
197
backend/ee/onyx/server/oauth/slack.py
Normal file
197
backend/ee/onyx/server/oauth/slack.py
Normal file
@@ -0,0 +1,197 @@
|
||||
import base64
|
||||
import uuid
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET
|
||||
from ee.onyx.server.oauth.api_router import router
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.models import CredentialBase
|
||||
|
||||
|
||||
class SlackOAuth:
    """Configuration and helpers for the Slack connector OAuth flow.

    References:
    - https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
    - Example: https://api.slack.com/authentication/oauth-v2#exchanging
    """

    class OAuthSession(BaseModel):
        """Stored in redis to be looked up on callback"""

        email: str
        redirect_on_success: str | None  # Where to send the user if OAuth flow succeeds

    CLIENT_ID = OAUTH_SLACK_CLIENT_ID
    CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET

    TOKEN_URL = "https://slack.com/api/oauth.v2.access"

    # SCOPE is per https://docs.danswer.dev/connectors/slack
    # Comma-separated list, interpolated directly into the authorize URL.
    BOT_SCOPE = ",".join(
        (
            "channels:history",
            "channels:read",
            "groups:history",
            "groups:read",
            "channels:join",
            "im:history",
            "users:read",
            "users:read.email",
            "usergroups:read",
        )
    )

    REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
    DEV_REDIRECT_URI = "https://redirectmeto.com/" + REDIRECT_URI

    @classmethod
    def generate_oauth_url(cls, state: str) -> str:
        """Return the Slack authorization URL that redirects to REDIRECT_URI."""
        return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

    @classmethod
    def generate_dev_oauth_url(cls, state: str) -> str:
        """Return the authorization URL that redirects to DEV_REDIRECT_URI.

        Dev mode workaround for localhost testing, see
        https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
        """
        return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

    @classmethod
    def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
        """Assemble the authorization URL for the given redirect URI and state.

        The values are interpolated as-is (no additional URL encoding).
        """
        query_parts = [
            f"client_id={cls.CLIENT_ID}",
            f"redirect_uri={redirect_uri}",
            f"scope={cls.BOT_SCOPE}",
            f"state={state}",
        ]
        return "https://slack.com/oauth/v2/authorize?" + "&".join(query_parts)

    @classmethod
    def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
        """Serialize the temporary OAuth state kept in redis until the callback.

        Returns a json string.
        """
        return SlackOAuth.OAuthSession(
            email=email,
            redirect_on_success=redirect_on_success,
        ).model_dump_json()

    @classmethod
    def parse_session(cls, session_json: str) -> OAuthSession:
        """Rebuild an OAuthSession from the json produced by session_dump_json."""
        return SlackOAuth.OAuthSession.model_validate_json(session_json)
|
||||
|
||||
|
||||
@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
    code: str,
    state: str,
    user: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
    """Complete the Slack OAuth flow and store the resulting bot token.

    Looks up the OAuth session stored in redis under the UUID encoded in
    ``state``, exchanges ``code`` for a token at the Slack token endpoint,
    and creates a Slack credential holding the returned access token.

    Raises:
        HTTPException: 500 if the client ID/secret is not configured; 400 if
            ``state`` is malformed, no matching session exists in redis, or
            Slack reports the exchange as not ok.

    Returns:
        A JSON response with ``success``/``message`` (plus ``finalize_url``,
        ``redirect_on_success``, ``team_id`` and ``authed_user_id`` on
        success). Other failures are reported as a 500 JSON response.
    """
    if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
        raise HTTPException(
            status_code=500,
            detail="Slack client ID or client secret is not configured.",
        )

    r = get_redis_client(tenant_id=tenant_id)

    # recover the state
    try:
        # Add padding back (Base64 decoding requires padding)
        padded_state = state + "=" * (-len(state) % 4)
        # Decode the Base64 string back to bytes
        uuid_bytes = base64.urlsafe_b64decode(padded_state)
        # Convert bytes back to a UUID
        oauth_uuid = uuid.UUID(bytes=uuid_bytes)
    except ValueError as err:
        # binascii.Error (bad base64) is a ValueError subclass, as is the
        # error uuid.UUID raises for input that is not exactly 16 bytes.
        # A malformed state is a client error, not a server failure.
        raise HTTPException(
            status_code=400,
            detail="Slack OAuth failed - malformed OAuth state.",
        ) from err

    oauth_uuid_str = str(oauth_uuid)

    r_key = f"da_oauth:{oauth_uuid_str}"

    session_json_bytes = cast(bytes, r.get(r_key))
    if not session_json_bytes:
        raise HTTPException(
            status_code=400,
            detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
        )

    session_json = session_json_bytes.decode("utf-8")
    try:
        session = SlackOAuth.parse_session(session_json)

        if not DEV_MODE:
            redirect_uri = SlackOAuth.REDIRECT_URI
        else:
            redirect_uri = SlackOAuth.DEV_REDIRECT_URI

        # Exchange the authorization code for an access token
        response = requests.post(
            SlackOAuth.TOKEN_URL,
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            data={
                "client_id": SlackOAuth.CLIENT_ID,
                "client_secret": SlackOAuth.CLIENT_SECRET,
                "code": code,
                "redirect_uri": redirect_uri,
            },
            # requests has no default timeout; bound the call so a stalled
            # token endpoint cannot hang this request handler forever
            timeout=30,
        )

        response_data = response.json()

        if not response_data.get("ok"):
            raise HTTPException(
                status_code=400,
                detail=f"Slack OAuth failed: {response_data.get('error')}",
            )

        # Extract token and team information
        access_token: str = response_data.get("access_token")
        # "team" / "authed_user" may be present but null, so fall back to an
        # empty dict instead of relying on the .get() default
        team_id: str | None = (response_data.get("team") or {}).get("id")
        authed_user_id: str | None = (response_data.get("authed_user") or {}).get(
            "id"
        )

        credential_info = CredentialBase(
            credential_json={"slack_bot_token": access_token},
            admin_public=True,
            source=DocumentSource.SLACK,
            name="Slack OAuth",
        )

        create_credential(credential_info, user, db_session)
    except HTTPException:
        # Preserve the intended status code (e.g. the 400 above) rather than
        # letting the generic handler below turn it into a 500
        raise
    except Exception as e:
        return JSONResponse(
            status_code=500,
            content={
                "success": False,
                "message": f"An error occurred during Slack OAuth: {str(e)}",
            },
        )
    finally:
        # the state is single-use: drop it whether or not the exchange succeeded
        r.delete(r_key)

    # return the result
    return JSONResponse(
        content={
            "success": True,
            "message": "Slack OAuth completed successfully.",
            "finalize_url": None,
            "redirect_on_success": session.redirect_on_success,
            "team_id": team_id,
            "authed_user_id": authed_user_id,
        }
    )
|
||||
@@ -2,6 +2,7 @@ import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -21,8 +22,10 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.configs.constants import SessionType
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
@@ -35,6 +38,8 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history(
|
||||
db_session: Session,
|
||||
@@ -107,6 +112,17 @@ def get_user_chat_sessions(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
# we specifically don't allow this endpoint if "anonymized" since
|
||||
# this is a direct query on the user id
|
||||
if ONYX_QUERY_HISTORY_TYPE in [
|
||||
QueryHistoryType.DISABLED,
|
||||
QueryHistoryType.ANONYMIZED,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Per user query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id, deleted=False, db_session=db_session, limit=0
|
||||
@@ -122,6 +138,7 @@ def get_user_chat_sessions(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
@@ -141,6 +158,12 @@ def get_chat_session_history(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
page_of_chat_sessions = get_page_of_chat_sessions(
|
||||
page_num=page_num,
|
||||
page_size=page_size,
|
||||
@@ -157,11 +180,16 @@ def get_chat_session_history(
|
||||
feedback_filter=feedback_type,
|
||||
)
|
||||
|
||||
minimal_chat_sessions: list[ChatSessionMinimal] = []
|
||||
|
||||
for chat_session in page_of_chat_sessions:
|
||||
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
minimal_chat_sessions.append(minimal_chat_session)
|
||||
|
||||
return PaginatedReturn(
|
||||
items=[
|
||||
ChatSessionMinimal.from_chat_session(chat_session)
|
||||
for chat_session in page_of_chat_sessions
|
||||
],
|
||||
items=minimal_chat_sessions,
|
||||
total_items=total_filtered_chat_sessions_count,
|
||||
)
|
||||
|
||||
@@ -172,6 +200,12 @@ def get_chat_session_admin(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -193,6 +227,9 @@ def get_chat_session_admin(
|
||||
f"Could not create snapshot for chat session with id '{chat_session_id}'",
|
||||
)
|
||||
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
return snapshot
|
||||
|
||||
|
||||
@@ -203,6 +240,12 @@ def get_query_history_as_csv(
|
||||
end: datetime | None = None,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
complete_chat_session_history = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
|
||||
@@ -213,6 +256,9 @@ def get_query_history_as_csv(
|
||||
|
||||
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
|
||||
for chat_session_snapshot in complete_chat_session_history:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
question_answer_pairs.extend(
|
||||
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
|
||||
)
|
||||
|
||||
@@ -7,6 +7,7 @@ from ee.onyx.configs.app_configs import STRIPE_PRICE_ID
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -41,7 +42,9 @@ def fetch_tenant_stripe_information(tenant_id: str) -> dict:
|
||||
return response.json()
|
||||
|
||||
|
||||
def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
||||
def fetch_billing_information(
|
||||
tenant_id: str,
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
@@ -52,8 +55,19 @@ def fetch_billing_information(tenant_id: str) -> BillingInformation:
|
||||
params = {"tenant_id": tenant_id}
|
||||
response = requests.get(url, headers=headers, params=params)
|
||||
response.raise_for_status()
|
||||
billing_info = BillingInformation(**response.json())
|
||||
return billing_info
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# Check if the response indicates no subscription
|
||||
if (
|
||||
isinstance(response_data, dict)
|
||||
and "subscribed" in response_data
|
||||
and not response_data["subscribed"]
|
||||
):
|
||||
return SubscriptionStatusResponse(**response_data)
|
||||
|
||||
# Otherwise, parse as BillingInformation
|
||||
return BillingInformation(**response_data)
|
||||
|
||||
|
||||
def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription:
|
||||
|
||||
@@ -200,25 +200,6 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4",
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
)
|
||||
|
||||
if ANTHROPIC_DEFAULT_API_KEY:
|
||||
anthropic_provider = LLMProviderUpsertRequest(
|
||||
name="Anthropic",
|
||||
@@ -227,6 +208,7 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
default_model_name="claude-3-7-sonnet-20250219",
|
||||
fast_default_model_name="claude-3-5-sonnet-20241022",
|
||||
model_names=ANTHROPIC_MODEL_NAMES,
|
||||
display_model_names=["claude-3-5-sonnet-20241022"],
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(anthropic_provider, db_session)
|
||||
@@ -238,6 +220,26 @@ def configure_default_api_keys(db_session: Session) -> None:
|
||||
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
|
||||
)
|
||||
|
||||
if OPENAI_DEFAULT_API_KEY:
|
||||
open_provider = LLMProviderUpsertRequest(
|
||||
name="OpenAI",
|
||||
provider=OPENAI_PROVIDER_NAME,
|
||||
api_key=OPENAI_DEFAULT_API_KEY,
|
||||
default_model_name="gpt-4o",
|
||||
fast_default_model_name="gpt-4o-mini",
|
||||
model_names=OPEN_AI_MODEL_NAMES,
|
||||
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
|
||||
)
|
||||
try:
|
||||
full_provider = upsert_llm_provider(open_provider, db_session)
|
||||
update_default_provider(full_provider.id, db_session)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to configure OpenAI provider: {e}")
|
||||
else:
|
||||
logger.error(
|
||||
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
|
||||
)
|
||||
|
||||
if COHERE_DEFAULT_API_KEY:
|
||||
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
|
||||
provider_type=EmbeddingProvider.COHERE,
|
||||
|
||||
@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
|
||||
|
||||
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
_API_KEY_HEADER_NAME = "Authorization"
|
||||
@@ -35,8 +36,7 @@ class ApiKeyDescriptor(BaseModel):
|
||||
|
||||
|
||||
def generate_api_key(tenant_id: str | None = None) -> str:
|
||||
# For backwards compatibility, if no tenant_id, generate old style key
|
||||
if not tenant_id:
|
||||
if not MULTI_TENANT or not tenant_id:
|
||||
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
|
||||
|
||||
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
|
||||
|
||||
@@ -2,6 +2,8 @@ import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formatdate
|
||||
from email.utils import make_msgid
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
@@ -13,6 +15,7 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
<!DOCTYPE html>
|
||||
@@ -150,8 +153,9 @@ def send_email(
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
msg["From"] = mail_from
|
||||
msg["Date"] = formatdate(localtime=True)
|
||||
msg["Message-ID"] = make_msgid(domain="onyx.app")
|
||||
|
||||
part_text = MIMEText(text_body, "plain")
|
||||
part_html = MIMEText(html_body, "html")
|
||||
@@ -173,7 +177,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
subject = "Your Onyx Subscription Has Been Canceled"
|
||||
heading = "Subscription Canceled"
|
||||
message = (
|
||||
"<p>We’re sorry to see you go.</p>"
|
||||
"<p>We're sorry to see you go.</p>"
|
||||
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
|
||||
"<p>If you change your mind, you can always come back!</p>"
|
||||
)
|
||||
@@ -239,13 +243,13 @@ def send_user_email_invite(
|
||||
def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
tenant_id: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
tenant_id: str | None = None,
|
||||
) -> None:
|
||||
# Builds a forgot password email with or without fancy HTML
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if tenant_id:
|
||||
if MULTI_TENANT:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
html_content = build_html_email("Reset Your Password", message)
|
||||
|
||||
@@ -214,7 +214,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
@@ -411,7 +411,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"refresh_token": refresh_token,
|
||||
}
|
||||
|
||||
user: User
|
||||
user: User | None = None
|
||||
|
||||
try:
|
||||
# Attempt to get user by OAuth account
|
||||
@@ -420,15 +420,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
except exceptions.UserNotExists:
|
||||
try:
|
||||
# Attempt to get user by email
|
||||
user = await self.get_by_email(account_email)
|
||||
user = await self.user_db.get_by_email(account_email)
|
||||
if not associate_by_email:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
# Make sure user is not None before adding OAuth account
|
||||
if user is not None:
|
||||
user = await self.user_db.add_oauth_account(
|
||||
user, oauth_account_dict
|
||||
)
|
||||
else:
|
||||
# This shouldn't happen since get_by_email would raise UserNotExists
|
||||
# but adding as a safeguard
|
||||
raise exceptions.UserNotExists()
|
||||
|
||||
# If user not found by OAuth account or email, create a new user
|
||||
except exceptions.UserNotExists:
|
||||
password = self.password_helper.generate()
|
||||
user_dict = {
|
||||
@@ -439,26 +444,36 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
|
||||
# Explicitly set the Postgres schema for this session to ensure
|
||||
# OAuth account creation happens in the correct tenant schema
|
||||
|
||||
# Add OAuth account
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
# Add OAuth account only if user creation was successful
|
||||
if user is not None:
|
||||
await self.user_db.add_oauth_account(user, oauth_account_dict)
|
||||
await self.on_after_register(user, request)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to create user account"
|
||||
)
|
||||
|
||||
else:
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
)
|
||||
# User exists, update OAuth account if needed
|
||||
if user is not None: # Add explicit check
|
||||
for existing_oauth_account in user.oauth_accounts:
|
||||
if (
|
||||
existing_oauth_account.account_id == account_id
|
||||
and existing_oauth_account.oauth_name == oauth_name
|
||||
):
|
||||
user = await self.user_db.update_oauth_account(
|
||||
user,
|
||||
# NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
|
||||
# but the type checker doesn't know that :(
|
||||
existing_oauth_account, # type: ignore
|
||||
oauth_account_dict,
|
||||
)
|
||||
|
||||
# Ensure user is not None before proceeding
|
||||
if user is None:
|
||||
raise HTTPException(
|
||||
status_code=500, detail="Failed to authenticate or create user"
|
||||
)
|
||||
|
||||
# NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
|
||||
# re-authenticate that frequently, so by default this is disabled
|
||||
@@ -553,7 +568,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async_return_default_schema,
|
||||
)(email=user.email)
|
||||
|
||||
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
|
||||
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
@@ -131,9 +132,9 @@ def on_task_postrun(
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
tenant_id = None
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
else:
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
|
||||
@@ -34,7 +34,7 @@ def _get_deletion_status(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
tenant_id: str,
|
||||
) -> TaskQueueState | None:
|
||||
"""We no longer store TaskQueueState in the DB for a deletion attempt.
|
||||
This function populates TaskQueueState by just checking redis.
|
||||
@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
tenant_id: str,
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id, credential_id, db_session, tenant_id
|
||||
|
||||
@@ -109,9 +109,7 @@ def revoke_tasks_blocking_deletion(
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
@@ -224,7 +222,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
@@ -345,7 +343,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
tenant_id: str, key_bytes: bytes, r: Redis
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -500,7 +498,7 @@ def monitor_connector_deletion_taskset(
|
||||
|
||||
|
||||
def validate_connector_deletion_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -540,7 +538,7 @@ def validate_connector_deletion_fences(
|
||||
|
||||
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
r: Redis,
|
||||
|
||||
@@ -221,7 +221,7 @@ def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> str | None:
|
||||
"""Returns a randomized payload id on success.
|
||||
Returns None if no syncing is required."""
|
||||
@@ -320,7 +320,7 @@ def try_creating_permissions_sync_task(
|
||||
def connector_permission_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
@@ -410,7 +410,6 @@ def connector_permission_sync_generator_task(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
@@ -510,7 +509,7 @@ def connector_permission_sync_generator_task(
|
||||
)
|
||||
def update_external_document_permissions_task(
|
||||
self: Task,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
@@ -585,7 +584,7 @@ def update_external_document_permissions_task(
|
||||
|
||||
|
||||
def validate_permission_sync_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -632,7 +631,7 @@ def validate_permission_sync_fences(
|
||||
|
||||
|
||||
def validate_permission_sync_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
reserved_tasks: set[str],
|
||||
@@ -842,7 +841,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
|
||||
@@ -123,7 +123,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
@@ -220,7 +220,7 @@ def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> str | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
@@ -306,7 +306,7 @@ def try_creating_external_group_sync_task(
|
||||
def connector_external_group_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
External group sync task for a given connector credential pair
|
||||
@@ -392,7 +392,6 @@ def connector_external_group_sync_generator_task(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
@@ -424,7 +423,7 @@ def connector_external_group_sync_generator_task(
|
||||
)
|
||||
external_user_groups: list[ExternalUserGroup] = []
|
||||
try:
|
||||
external_user_groups = ext_group_sync_func(cc_pair)
|
||||
external_user_groups = ext_group_sync_func(tenant_id, cc_pair)
|
||||
except ConnectorValidationError as e:
|
||||
msg = f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
|
||||
update_connector_credential_pair(
|
||||
@@ -494,7 +493,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
|
||||
def validate_external_group_sync_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
@@ -526,7 +525,7 @@ def validate_external_group_sync_fences(
|
||||
|
||||
|
||||
def validate_external_group_sync_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
|
||||
@@ -182,7 +182,7 @@ class SimpleJobResult:
|
||||
|
||||
|
||||
class ConnectorIndexingContext(BaseModel):
|
||||
tenant_id: str | None
|
||||
tenant_id: str
|
||||
cc_pair_id: int
|
||||
search_settings_id: int
|
||||
index_attempt_id: int
|
||||
@@ -210,7 +210,7 @@ class ConnectorIndexingLogBuilder:
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
@@ -358,7 +358,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
|
||||
@@ -598,7 +598,7 @@ def connector_indexing_task(
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -890,7 +890,7 @@ def connector_indexing_proxy_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
|
||||
and forking is inherently unstable.
|
||||
@@ -1170,7 +1170,7 @@ def connector_indexing_proxy_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
|
||||
soft_time_limit=300,
|
||||
)
|
||||
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
|
||||
def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
|
||||
"""Clean up old checkpoints that are older than 7 days."""
|
||||
locked = False
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -187,7 +187,7 @@ class IndexingCallback(IndexingCallbackBase):
|
||||
|
||||
|
||||
def validate_indexing_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
@@ -311,7 +311,7 @@ def validate_indexing_fence(
|
||||
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
@@ -442,7 +442,7 @@ def try_creating_indexing_task(
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
|
||||
if not LLM_MODEL_UPDATE_API_URL:
|
||||
raise ValueError("LLM model update API URL not configured")
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class Metric(BaseModel):
|
||||
}
|
||||
task_logger.info(json.dumps(data))
|
||||
|
||||
def emit(self, tenant_id: str | None) -> None:
|
||||
def emit(self, tenant_id: str) -> None:
|
||||
# Convert value to appropriate type based on the input value
|
||||
bool_value = None
|
||||
float_value = None
|
||||
@@ -656,7 +656,7 @@ def build_job_id(
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
"""Collect and emit metrics about background processes.
|
||||
This task runs periodically to gather metrics about:
|
||||
- Queue lengths for different Celery queues
|
||||
@@ -864,7 +864,7 @@ def cloud_monitor_celery_queues(
|
||||
|
||||
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
|
||||
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
|
||||
def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
|
||||
return monitor_celery_queues_helper(self)
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
|
||||
@@ -114,7 +114,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
@@ -211,7 +211,7 @@ def try_creating_prune_generator_task(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> str | None:
|
||||
"""Checks for any conditions that should block the pruning generator task from being
|
||||
created, then creates the task.
|
||||
@@ -333,7 +333,7 @@ def connector_pruning_generator_task(
|
||||
cc_pair_id: int,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -521,7 +521,7 @@ def connector_pruning_generator_task(
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -567,7 +567,7 @@ def monitor_ccpair_pruning_taskset(
|
||||
|
||||
|
||||
def validate_pruning_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -615,7 +615,7 @@ def validate_pruning_fences(
|
||||
|
||||
|
||||
def validate_pruning_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
queued_tasks: set[str],
|
||||
|
||||
@@ -32,7 +32,7 @@ class RetryDocumentIndex:
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
return self.index.delete_single(
|
||||
@@ -50,7 +50,7 @@ class RetryDocumentIndex:
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
|
||||
@@ -76,7 +76,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
document_id: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""A lightweight subtask used to clean up document to cc pair relationships.
|
||||
Created by connection deletion and connector pruning parent tasks."""
|
||||
@@ -297,7 +297,8 @@ def cloud_beat_task_generator(
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
tenant_ids: list[str] | list[None] = []
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
@@ -325,6 +326,8 @@ def cloud_beat_task_generator(
|
||||
expires=expires,
|
||||
ignore_result=True,
|
||||
)
|
||||
|
||||
num_processed_tenants += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -344,6 +347,7 @@ def cloud_beat_task_generator(
|
||||
task_logger.info(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -76,7 +76,7 @@ logger = setup_logger()
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
@@ -208,7 +208,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
|
||||
@@ -284,7 +284,7 @@ def try_generate_document_set_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -361,7 +361,7 @@ def try_generate_user_group_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -448,7 +448,7 @@ def monitor_connector_taskset(r: Redis) -> None:
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
@@ -523,9 +523,7 @@ def monitor_document_set_taskset(
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, *, tenant_id: str | None
|
||||
) -> bool:
|
||||
def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
@@ -11,10 +11,27 @@ def emit_background_error(
|
||||
"""Currently just saves a row in the background_errors table.
|
||||
|
||||
In the future, could create notifications based on the severity."""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
try:
|
||||
error_message = ""
|
||||
|
||||
# try to write to the db, but handle IntegrityError specifically
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_background_error(db_session, message, cc_pair_id)
|
||||
except IntegrityError as e:
|
||||
# Log an error if the cc_pair_id was deleted or any other exception occurs
|
||||
error_message = f"Failed to create background error: {str(e)}. Original message: {message}"
|
||||
except IntegrityError as e:
|
||||
# Log an error if the cc_pair_id was deleted or any other exception occurs
|
||||
error_message = (
|
||||
f"Failed to create background error: {str(e)}. Original message: {message}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not error_message:
|
||||
return
|
||||
|
||||
# if we get here from an IntegrityError, try to write the error message to the db
|
||||
# we need a new session because the first session is now invalid
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_background_error(db_session, error_message, None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -16,7 +16,7 @@ from typing import Optional
|
||||
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.setup import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
@@ -55,6 +55,7 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -67,7 +68,6 @@ def _get_connector_runner(
|
||||
batch_size: int,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
tenant_id: str | None,
|
||||
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
|
||||
) -> ConnectorRunner:
|
||||
"""
|
||||
@@ -86,7 +86,6 @@ def _get_connector_runner(
|
||||
input_type=task,
|
||||
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# validate the connector settings
|
||||
@@ -94,10 +93,11 @@ def _get_connector_runner(
|
||||
runnable_connector.validate_connector_settings()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unable to instantiate connector due to {e}")
|
||||
|
||||
logger.exception("Unable to instantiate connector.")
|
||||
# since we failed to even instantiate the connector, we pause the CCPair since
|
||||
# it will never succeed. Sometimes there are cases where the connector will
|
||||
# it will never succeed
|
||||
|
||||
# Sometimes there are cases where the connector will
|
||||
# intermittently fail to initialize in which case we should pass in
|
||||
# leave_connector_active=True to allow it to continue.
|
||||
# For example, if there is nightly maintenance on a Confluence Server instance,
|
||||
@@ -241,7 +241,7 @@ def _check_failure_threshold(
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -388,7 +388,6 @@ def _run_indexing(
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
@@ -681,7 +680,7 @@ def _run_indexing(
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
index_attempt_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
connector_credential_pair_id: int,
|
||||
is_ee: bool = False,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
@@ -701,7 +700,7 @@ def run_indexing_entrypoint(
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
if MULTI_TENANT:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
connector_name = attempt.connector_credential_pair.connector.name
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import cast
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
|
||||
#####
|
||||
@@ -29,6 +30,9 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
) # 1 day
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
|
||||
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
|
||||
)
|
||||
|
||||
#####
|
||||
# Web Configs
|
||||
|
||||
@@ -213,6 +213,12 @@ class AuthType(str, Enum):
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
class QueryHistoryType(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
ANONYMIZED = "anonymized"
|
||||
NORMAL = "normal"
|
||||
|
||||
|
||||
# Special characters for password validation
|
||||
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
|
||||
@@ -11,17 +11,20 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.onyx_confluence import build_confluence_client
|
||||
from onyx.connectors.confluence.onyx_confluence import attachment_to_content
|
||||
from onyx.connectors.confluence.onyx_confluence import (
|
||||
extract_text_from_confluence_html,
|
||||
)
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
from onyx.connectors.confluence.utils import attachment_to_content
|
||||
from onyx.connectors.confluence.utils import build_confluence_document_id
|
||||
from onyx.connectors.confluence.utils import datetime_from_string
|
||||
from onyx.connectors.confluence.utils import extract_text_from_confluence_html
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedError
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
@@ -83,7 +86,9 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
|
||||
)
|
||||
|
||||
|
||||
class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class ConfluenceConnector(
|
||||
LoadConnector, PollConnector, SlimConnector, CredentialsConnector
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
wiki_base: str,
|
||||
@@ -102,7 +107,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self.is_cloud = is_cloud
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
@@ -137,6 +141,19 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
self.credentials_provider: CredentialsProviderInterface | None = None
|
||||
|
||||
self.probe_kwargs = {
|
||||
"max_backoff_retries": 6,
|
||||
"max_backoff_seconds": 10,
|
||||
}
|
||||
|
||||
self.final_kwargs = {
|
||||
"max_backoff_retries": 10,
|
||||
"max_backoff_seconds": 60,
|
||||
}
|
||||
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
@@ -144,15 +161,22 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
raise ConnectorMissingCredentialError("Confluence")
|
||||
return self._confluence_client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
# see https://github.com/atlassian-api/atlassian-python-api/blob/master/atlassian/rest_client.py
|
||||
# for a list of other hidden constructor args
|
||||
self._confluence_client = build_confluence_client(
|
||||
credentials=credentials,
|
||||
is_cloud=self.is_cloud,
|
||||
wiki_base=self.wiki_base,
|
||||
def set_credentials_provider(
|
||||
self, credentials_provider: CredentialsProviderInterface
|
||||
) -> None:
|
||||
self.credentials_provider = credentials_provider
|
||||
|
||||
# raises exception if there's a problem
|
||||
confluence_client = OnyxConfluence(
|
||||
self.is_cloud, self.wiki_base, credentials_provider
|
||||
)
|
||||
return None
|
||||
confluence_client._probe_connection(**self.probe_kwargs)
|
||||
confluence_client._initialize_connection(**self.final_kwargs)
|
||||
|
||||
self._confluence_client = confluence_client
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
raise NotImplementedError("Use set_credentials_provider with this connector.")
|
||||
|
||||
def _construct_page_query(
|
||||
self,
|
||||
@@ -202,12 +226,17 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
return comment_string
|
||||
|
||||
def _convert_object_to_document(
|
||||
self, confluence_object: dict[str, Any]
|
||||
self,
|
||||
confluence_object: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Takes in a confluence object, extracts all metadata, and converts it into a document.
|
||||
If its a page, it extracts the text, adds the comments for the document text.
|
||||
If its an attachment, it just downloads the attachment and converts that into a document.
|
||||
|
||||
parent_content_id: if the object is an attachment, specifies the content id that
|
||||
the attachment is attached to
|
||||
"""
|
||||
# The url and the id are the same
|
||||
object_url = build_confluence_document_id(
|
||||
@@ -226,7 +255,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
|
||||
elif confluence_object["type"] == "attachment":
|
||||
object_text = attachment_to_content(
|
||||
confluence_client=self.confluence_client, attachment=confluence_object
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=confluence_object,
|
||||
parent_content_id=parent_content_id,
|
||||
)
|
||||
|
||||
if object_text is None:
|
||||
@@ -302,7 +333,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
cql=attachment_query,
|
||||
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
|
||||
):
|
||||
doc = self._convert_object_to_document(attachment)
|
||||
doc = self._convert_object_to_document(attachment, confluence_page_id)
|
||||
if doc is not None:
|
||||
doc_batch.append(doc)
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
|
||||
@@ -1,19 +1,37 @@
|
||||
import math
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
from urllib.parse import quote
|
||||
|
||||
import bs4
|
||||
from atlassian import Confluence # type:ignore
|
||||
from pydantic import BaseModel
|
||||
from redis import Redis
|
||||
from requests import HTTPError
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -22,12 +40,14 @@ logger = setup_logger()
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-76433
|
||||
_PROBLEMATIC_EXPANSIONS = "body.storage.value"
|
||||
_REPLACEMENT_EXPANSIONS = "body.view.value"
|
||||
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
class ConfluenceRateLimitError(Exception):
|
||||
pass
|
||||
@@ -43,124 +63,349 @@ class ConfluenceUser(BaseModel):
|
||||
type: str
|
||||
|
||||
|
||||
def _handle_http_error(e: HTTPError, attempt: int) -> int:
    """Work out how long to back off after a rate-limited Confluence request.

    The error is re-raised unchanged unless the response is a rate-limit
    response (status 429, or a body containing the rate-limit message).

    Args:
        e: the HTTP error raised by the Confluence call.
        attempt: zero-based attempt number, used as the exponent of the
            backoff when no usable ``Retry-After`` header is present.

    Returns:
        The ``time.monotonic()`` value, rounded up to a whole number, until
        which the caller should wait before retrying.

    Raises:
        HTTPError: ``e`` itself, when the response or its headers are missing
            or the error is not a rate-limit error.
    """
    MIN_DELAY = 2
    MAX_DELAY = 60
    STARTING_DELAY = 5
    BACKOFF = 2

    response = e.response
    # Guard first: everything below dereferences the response and its headers.
    if response is None or response.headers is None:
        logger.warning("HTTPError with `None` as response or as headers")
        raise e

    rate_limited = (
        response.status_code == 429
        or RATE_LIMIT_MESSAGE_LOWERCASE in response.text.lower()
    )
    if not rate_limited:
        raise e

    server_delay = None
    header_value = response.headers.get("Retry-After")
    if header_value is not None:
        try:
            server_delay = int(header_value)
        except ValueError:
            # Non-integer header value: fall through to exponential backoff.
            server_delay = None
        else:
            if server_delay > MAX_DELAY:
                logger.warning(
                    f"Clamping retry_after from {server_delay} to {MAX_DELAY} seconds..."
                )
                server_delay = MAX_DELAY
            if server_delay < MIN_DELAY:
                server_delay = MIN_DELAY

    if server_delay is None:
        logger.warning(
            "Rate limiting without retry header. Retrying with exponential backoff..."
        )
        delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
    else:
        logger.warning(
            f"Rate limiting with retry header. Retrying after {server_delay} seconds..."
        )
        delay = server_delay

    return math.ceil(time.monotonic() + delay)
|
||||
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def handle_confluence_rate_limit(confluence_call: F) -> F:
    """Wrap a Confluence client call with rate-limit and flaky-client retries.

    The wrapped call is attempted up to ``MAX_RETRIES`` times within an
    overall ``TIMEOUT`` budget:

    * ``HTTPError``: ``_handle_http_error`` re-raises anything that is not a
      rate-limit response; for rate limits it returns the monotonic deadline
      to wait for before the next attempt.
    * ``AttributeError``: raised intermittently from inside the Confluence
      library; retried after a fixed pause.

    In both cases the error from the final attempt is re-raised, so the
    wrapper never falls out of the loop and returns ``None`` by accident.

    Args:
        confluence_call: the callable to protect.

    Returns:
        A callable with the same call signature as ``confluence_call``.

    Raises:
        TimeoutError: when the retry budget of ``TIMEOUT`` seconds is used up.
        HTTPError: for non-rate-limit errors, or when still rate limited on
            the last attempt.
        AttributeError: when the client still fails on the last attempt.
    """

    def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
        MAX_RETRIES = 5

        TIMEOUT = 600
        timeout_at = time.monotonic() + TIMEOUT

        for attempt in range(MAX_RETRIES):
            if time.monotonic() > timeout_at:
                raise TimeoutError(
                    f"Confluence call attempts took longer than {TIMEOUT} seconds."
                )

            try:
                # we're relying more on the client to rate limit itself
                # and applying our own retries in a more specific set of circumstances
                return confluence_call(*args, **kwargs)
            except HTTPError as e:
                # re-raises immediately for anything that is not a rate limit
                delay_until = _handle_http_error(e, attempt)

                # No attempts left: surface the error instead of sleeping and
                # then silently returning None to the caller.
                if attempt == MAX_RETRIES - 1:
                    raise

                # delay_until is a time.monotonic() deadline, not a duration,
                # so convert it before reporting it as "seconds".
                delay_seconds = max(0.0, delay_until - time.monotonic())
                logger.warning(
                    f"HTTPError in confluence call. "
                    f"Retrying in {delay_seconds:.0f} seconds..."
                )
                while time.monotonic() < delay_until:
                    # in the future, check a signal here to exit
                    time.sleep(1)
            except AttributeError as e:
                # Some error within the Confluence library, unclear why it fails.
                # Users reported it to be intermittent, so just retry
                if attempt == MAX_RETRIES - 1:
                    raise e

                logger.exception(
                    "Confluence Client raised an AttributeError. Retrying..."
                )
                time.sleep(5)

    return cast(F, wrapped_call)
|
||||
|
||||
|
||||
_DEFAULT_PAGINATION_LIMIT = 1000
|
||||
_MINIMUM_PAGINATION_LIMIT = 50
|
||||
|
||||
|
||||
class OnyxConfluence(Confluence):
|
||||
class OnyxConfluence:
|
||||
"""
|
||||
This is a custom Confluence class that overrides the default Confluence class to add a custom CQL method.
|
||||
This is a custom Confluence class that:
|
||||
|
||||
A. overrides the default Confluence class to add a custom CQL method.
|
||||
B.
|
||||
This is necessary because the default Confluence class does not properly support cql expansions.
|
||||
All methods are automatically wrapped with handle_confluence_rate_limit.
|
||||
"""
|
||||
|
||||
def __init__(self, url: str, *args: Any, **kwargs: Any) -> None:
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
CREDENTIAL_PREFIX = "connector:confluence:credential"
|
||||
CREDENTIAL_TTL = 300 # 5 min
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
def __init__(
|
||||
self,
|
||||
is_cloud: bool,
|
||||
url: str,
|
||||
credentials_provider: CredentialsProviderInterface,
|
||||
) -> None:
|
||||
self._is_cloud = is_cloud
|
||||
self._url = url.rstrip("/")
|
||||
self._credentials_provider = credentials_provider
|
||||
|
||||
self.redis_client: Redis | None = None
|
||||
self.static_credentials: dict[str, Any] | None = None
|
||||
if self._credentials_provider.is_dynamic():
|
||||
self.redis_client = get_redis_client(
|
||||
tenant_id=credentials_provider.get_tenant_id()
|
||||
)
|
||||
else:
|
||||
self.static_credentials = self._credentials_provider.get_credentials()
|
||||
|
||||
self._confluence = Confluence(url)
|
||||
self.credential_key: str = (
|
||||
self.CREDENTIAL_PREFIX
|
||||
+ f":credential_{self._credentials_provider.get_provider_key()}"
|
||||
)
|
||||
|
||||
self._kwargs: Any = None
|
||||
|
||||
self.shared_base_kwargs = {
|
||||
"api_version": "cloud" if is_cloud else "latest",
|
||||
"backoff_and_retry": True,
|
||||
"cloud": is_cloud,
|
||||
}
|
||||
|
||||
def _renew_credentials(self) -> tuple[dict[str, Any], bool]:
|
||||
"""credential_json - the current json credentials
|
||||
Returns a tuple
|
||||
1. The up to date credentials
|
||||
2. True if the credentials were updated
|
||||
|
||||
This method is intended to be used within a distributed lock.
|
||||
Lock, call this, update credentials if the tokens were refreshed, then release
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
wrap it with handle_confluence_rate_limit.
|
||||
"""
|
||||
for attr_name in dir(self):
|
||||
if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
|
||||
setattr(
|
||||
self,
|
||||
attr_name,
|
||||
handle_confluence_rate_limit(getattr(self, attr_name)),
|
||||
# static credentials are preloaded, so no locking/redis required
|
||||
if self.static_credentials:
|
||||
return self.static_credentials, False
|
||||
|
||||
if not self.redis_client:
|
||||
raise RuntimeError("self.redis_client is None")
|
||||
|
||||
# dynamic credentials need locking
|
||||
# check redis first, then fallback to the DB
|
||||
credential_raw = self.redis_client.get(self.credential_key)
|
||||
if credential_raw is not None:
|
||||
credential_bytes = cast(bytes, credential_raw)
|
||||
credential_str = credential_bytes.decode("utf-8")
|
||||
credential_json: dict[str, Any] = json.loads(credential_str)
|
||||
else:
|
||||
credential_json = self._credentials_provider.get_credentials()
|
||||
|
||||
if "confluence_refresh_token" not in credential_json:
|
||||
# static credentials ... cache them permanently and return
|
||||
self.static_credentials = credential_json
|
||||
return credential_json, False
|
||||
|
||||
# check if we should refresh tokens. we're deciding to refresh halfway
|
||||
# to expiration
|
||||
now = datetime.now(timezone.utc)
|
||||
created_at = datetime.fromisoformat(credential_json["created_at"])
|
||||
expires_in: int = credential_json["expires_in"]
|
||||
renew_at = created_at + timedelta(seconds=expires_in // 2)
|
||||
if now <= renew_at:
|
||||
# cached/current credentials are reasonably up to date
|
||||
return credential_json, False
|
||||
|
||||
# we need to refresh
|
||||
logger.info("Renewing Confluence Cloud credentials...")
|
||||
new_credentials = confluence_refresh_tokens(
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID,
|
||||
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET,
|
||||
credential_json["cloud_id"],
|
||||
credential_json["confluence_refresh_token"],
|
||||
)
|
||||
|
||||
# store the new credentials to redis and to the db thru the provider
|
||||
# redis: we use a 5 min TTL because we are given a 10 minute grace period
|
||||
# when keys are rotated. it's easier to expire the cached credentials
|
||||
# reasonably frequently rather than trying to handle strong synchronization
|
||||
# between the db and redis everywhere the credentials might be updated
|
||||
new_credential_str = json.dumps(new_credentials)
|
||||
self.redis_client.set(
|
||||
self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL
|
||||
)
|
||||
self._credentials_provider.set_credentials(new_credentials)
|
||||
|
||||
return new_credentials, True
|
||||
|
||||
@staticmethod
|
||||
def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]:
|
||||
oauth2_dict: dict[str, Any] = {}
|
||||
if "confluence_refresh_token" in credentials:
|
||||
oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
oauth2_dict["token"] = {}
|
||||
oauth2_dict["token"]["access_token"] = credentials[
|
||||
"confluence_access_token"
|
||||
]
|
||||
return oauth2_dict
|
||||
|
||||
def _probe_connection(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
|
||||
|
||||
with self._credentials_provider:
|
||||
credentials, _ = self._renew_credentials()
|
||||
|
||||
# probe connection with direct client, no retries
|
||||
if "confluence_refresh_token" in credentials:
|
||||
logger.info("Probing Confluence with OAuth Access Token.")
|
||||
|
||||
oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(
|
||||
credentials
|
||||
)
|
||||
url = (
|
||||
f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
|
||||
)
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url, oauth2=oauth2_dict, **merged_kwargs
|
||||
)
|
||||
else:
|
||||
logger.info("Probing Confluence with Personal Access Token.")
|
||||
url = self._url
|
||||
if self._is_cloud:
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
username=credentials["confluence_username"],
|
||||
password=credentials["confluence_access_token"],
|
||||
**merged_kwargs,
|
||||
)
|
||||
else:
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
url=url,
|
||||
token=credentials["confluence_access_token"],
|
||||
**merged_kwargs,
|
||||
)
|
||||
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfortunately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {url}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
)
|
||||
|
||||
logger.info("Confluence probe succeeded.")
|
||||
|
||||
def _initialize_connection(
|
||||
self,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Called externally to init the connection in a thread safe manner."""
|
||||
merged_kwargs = {**self.shared_base_kwargs, **kwargs}
|
||||
with self._credentials_provider:
|
||||
credentials, _ = self._renew_credentials()
|
||||
self._confluence = self._initialize_connection_helper(
|
||||
credentials, **merged_kwargs
|
||||
)
|
||||
self._kwargs = merged_kwargs
|
||||
|
||||
def _initialize_connection_helper(
    self,
    credentials: dict[str, Any],
    **kwargs: Any,
) -> Confluence:
    """Construct an atlassian ``Confluence`` client for ``credentials``.

    Called internally. Any distributed locking needed to stop several
    threads from modifying the credentials at once must be handled by the
    caller, around this function.

    Args:
        credentials: the credential dict to authenticate with.
        **kwargs: extra keyword arguments forwarded to ``Confluence``.

    Returns:
        A new ``Confluence`` client; nothing is stored on ``self``.
    """
    # OAuth credentials are recognised by the presence of a refresh token and
    # are routed through the Atlassian API gateway for the site's cloud id.
    if "confluence_refresh_token" in credentials:
        logger.info("Connecting to Confluence Cloud with OAuth Access Token.")

        oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials)
        gateway_url = (
            f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}"
        )
        return Confluence(url=gateway_url, oauth2=oauth2_dict, **kwargs)

    logger.info("Connecting to Confluence with Personal Access Token.")

    if self._is_cloud:
        # Cloud: username plus access token passed as the password
        auth_kwargs: dict[str, Any] = {
            "username": credentials["confluence_username"],
            "password": credentials["confluence_access_token"],
        }
    else:
        # Non-cloud: the access token alone, passed as ``token``
        auth_kwargs = {"token": credentials["confluence_access_token"]}

    return Confluence(url=self._url, **auth_kwargs, **kwargs)
|
||||
|
||||
# https://developer.atlassian.com/cloud/confluence/rate-limiting/
|
||||
# this uses the native rate limiting option provided by the
|
||||
# confluence client and otherwise applies a simpler set of error handling
|
||||
def _make_rate_limited_confluence_method(
|
||||
self, name: str, credential_provider: CredentialsProviderInterface | None
|
||||
) -> Callable[..., Any]:
|
||||
def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
|
||||
MAX_RETRIES = 5
|
||||
|
||||
TIMEOUT = 600
|
||||
timeout_at = time.monotonic() + TIMEOUT
|
||||
|
||||
for attempt in range(MAX_RETRIES):
|
||||
if time.monotonic() > timeout_at:
|
||||
raise TimeoutError(
|
||||
f"Confluence call attempts took longer than {TIMEOUT} seconds."
|
||||
)
|
||||
|
||||
# we're relying more on the client to rate limit itself
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
try:
|
||||
if credential_provider:
|
||||
with credential_provider:
|
||||
credentials, renewed = self._renew_credentials()
|
||||
if renewed:
|
||||
self._confluence = self._initialize_connection_helper(
|
||||
credentials, **self._kwargs
|
||||
)
|
||||
attr = getattr(self._confluence, name, None)
|
||||
if attr is None:
|
||||
# The underlying Confluence client doesn't have this attribute
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
return attr(*args, **kwargs)
|
||||
else:
|
||||
attr = getattr(self._confluence, name, None)
|
||||
if attr is None:
|
||||
# The underlying Confluence client doesn't have this attribute
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. "
|
||||
f"Retrying in {delay_until} seconds..."
|
||||
)
|
||||
while time.monotonic() < delay_until:
|
||||
# in the future, check a signal here to exit
|
||||
time.sleep(1)
|
||||
except AttributeError as e:
|
||||
# Some error within the Confluence library, unclear why it fails.
|
||||
# Users reported it to be intermittent, so just retry
|
||||
if attempt == MAX_RETRIES - 1:
|
||||
raise e
|
||||
|
||||
logger.exception(
|
||||
"Confluence Client raised an AttributeError. Retrying..."
|
||||
)
|
||||
time.sleep(5)
|
||||
|
||||
return wrapped_call
|
||||
|
||||
# def _wrap_methods(self) -> None:
|
||||
# """
|
||||
# For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
# wrap it with handle_confluence_rate_limit.
|
||||
# """
|
||||
# for attr_name in dir(self):
|
||||
# if callable(getattr(self, attr_name)) and not attr_name.startswith("_"):
|
||||
# setattr(
|
||||
# self,
|
||||
# attr_name,
|
||||
# handle_confluence_rate_limit(getattr(self, attr_name)),
|
||||
# )
|
||||
|
||||
# def _ensure_token_valid(self) -> None:
|
||||
# if self._token_is_expired():
|
||||
# self._refresh_token()
|
||||
# # Re-init the Confluence client with the originally stored args
|
||||
# self._confluence = Confluence(self._url, *self._args, **self._kwargs)
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
"""Dynamically intercept attribute/method access."""
|
||||
attr = getattr(self._confluence, name, None)
|
||||
if attr is None:
|
||||
# The underlying Confluence client doesn't have this attribute
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{name}'"
|
||||
)
|
||||
|
||||
# If it's not a method, just return it after ensuring token validity
|
||||
if not callable(attr):
|
||||
return attr
|
||||
|
||||
# skip methods that start with "_"
|
||||
if name.startswith("_"):
|
||||
return attr
|
||||
|
||||
# wrap the method with our retry handler
|
||||
rate_limited_method: Callable[
|
||||
..., Any
|
||||
] = self._make_rate_limited_confluence_method(name, self._credentials_provider)
|
||||
|
||||
def wrapped_method(*args: Any, **kwargs: Any) -> Any:
|
||||
return rate_limited_method(*args, **kwargs)
|
||||
|
||||
return wrapped_method
|
||||
|
||||
def _paginate_url(
|
||||
self, url_suffix: str, limit: int | None = None, auto_paginate: bool = False
|
||||
@@ -507,63 +752,212 @@ class OnyxConfluence(Confluence):
|
||||
return response
|
||||
|
||||
|
||||
def _validate_connector_configuration(
|
||||
credentials: dict[str, Any],
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> None:
|
||||
# test connection with direct client, no retries
|
||||
confluence_client_with_minimal_retries = Confluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
url=wiki_base.rstrip("/"),
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=6,
|
||||
max_backoff_seconds=10,
|
||||
def get_user_email_from_username__server(
    confluence_client: OnyxConfluence, user_name: str
) -> str | None:
    """Look up a user's email via the client's mobile-parameters call.

    Successful lookups are memoised in the module-level ``_USER_EMAIL_CACHE``.
    Failures are stored as ``None``, which is indistinguishable from "not
    cached", so a failed lookup is attempted again on the next call.

    Args:
        confluence_client: client used to query Confluence.
        user_name: the username to resolve.

    Returns:
        The email, or ``None`` if the lookup failed or returned no email.
    """
    known_email = _USER_EMAIL_CACHE.get(user_name)
    if known_email is not None:
        return known_email

    try:
        email = confluence_client.get_mobile_parameters(user_name).get("email")
    except Exception:
        logger.warning(f"failed to get confluence email for {user_name}")
        # Deliberately cache None rather than a failure marker, so the
        # lookup is retried later (e.g. on the next group sync).
        email = None

    _USER_EMAIL_CACHE[user_name] = email
    return email
|
||||
|
||||
|
||||
def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
    """Resolve a Confluence user id to its display name.

    The id is tried as a userkey first and, if that yields no display name,
    as an account id. Results are memoised in the module-level
    ``_USER_ID_TO_DISPLAY_NAME_CACHE``; a cached ``None`` is looked up again
    on the next call.

    Args:
        confluence_client: client used to query Confluence.
        user_id: the user id (i.e. the account-id or userkey).

    Returns:
        The display name, or ``_USER_NOT_FOUND`` when neither lookup
        produced one (e.g. the user is deactivated or not found).
    """
    if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
        display_name = None
        for lookup_name in (
            "get_user_details_by_userkey",
            "get_user_details_by_accountid",
        ):
            try:
                # resolve the method inside the try so that a failure to
                # fetch it is swallowed exactly like a failed call
                details = getattr(confluence_client, lookup_name)(user_id)
                display_name = details.get("displayName")
            except Exception:
                display_name = None

            if display_name:
                break

        _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = display_name

    return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
if "api.atlassian.com" in confluence_client.url:
|
||||
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
|
||||
if not parent_content_id:
|
||||
logger.warning(
|
||||
"parent_content_id is required to download attachments from Confluence Cloud!"
|
||||
)
|
||||
return None
|
||||
|
||||
download_link = (
|
||||
confluence_client.url
|
||||
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
|
||||
)
|
||||
else:
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
|
||||
# why are we using session.get here? we probably won't retry these ... is that ok?
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfortunately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
return extracted_text
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {wiki_base}! "
|
||||
"Check your credentials and wiki_base and make sure "
|
||||
"is_cloud is set correctly."
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
fetched_titles: set[str],
|
||||
) -> str:
|
||||
"""Parse a Confluence html page and replace the 'user Id' by the real
|
||||
User Display Name
|
||||
|
||||
Args:
|
||||
confluence_object (dict): The confluence object as a dict
|
||||
confluence_client (Confluence): Confluence client
|
||||
fetched_titles (set[str]): The titles of the pages that have already been fetched
|
||||
Returns:
|
||||
str: loaded and formatted Confluence page
|
||||
"""
|
||||
body = confluence_object["body"]
|
||||
object_html = body.get("storage", body.get("view", {})).get("value")
|
||||
|
||||
soup = bs4.BeautifulSoup(object_html, "html.parser")
|
||||
for user in soup.findAll("ri:user"):
|
||||
user_id = (
|
||||
user.attrs["ri:account-id"]
|
||||
if "ri:account-id" in user.attrs
|
||||
else user.get("ri:userkey")
|
||||
)
|
||||
if not user_id:
|
||||
logger.warning(
|
||||
"ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
|
||||
)
|
||||
continue
|
||||
# Include @ sign for tagging, more clear for LLM
|
||||
user.replaceWith("@" + _get_user(confluence_client, user_id))
|
||||
|
||||
for html_page_reference in soup.findAll("ac:structured-macro"):
|
||||
# Here, we only want to process page within page macros
|
||||
if html_page_reference.attrs.get("ac:name") != "include":
|
||||
continue
|
||||
|
||||
page_data = html_page_reference.find("ri:page")
|
||||
if not page_data:
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because because page data is missing"
|
||||
)
|
||||
continue
|
||||
|
||||
page_title = page_data.attrs.get("ri:content-title")
|
||||
if not page_title:
|
||||
# only fetch pages that have a title
|
||||
logger.warning(
|
||||
f"Skipping retrieval of {html_page_reference} because it has no title"
|
||||
)
|
||||
continue
|
||||
|
||||
if page_title in fetched_titles:
|
||||
# prevent recursive fetching of pages
|
||||
logger.debug(f"Skipping {page_title} because it has already been fetched")
|
||||
continue
|
||||
|
||||
fetched_titles.add(page_title)
|
||||
|
||||
# Wrap this in a try-except because there are some pages that might not exist
|
||||
try:
|
||||
page_query = f"type=page and title='{quote(page_title)}'"
|
||||
|
||||
page_contents: dict[str, Any] | None = None
|
||||
# Confluence enforces title uniqueness, so we should only get one result here
|
||||
for page in confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand="body.storage.value",
|
||||
limit=1,
|
||||
):
|
||||
page_contents = page
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error getting page contents for object {confluence_object}: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not page_contents:
|
||||
continue
|
||||
|
||||
text_from_page = extract_text_from_confluence_html(
|
||||
confluence_client=confluence_client,
|
||||
confluence_object=page_contents,
|
||||
fetched_titles=fetched_titles,
|
||||
)
|
||||
|
||||
html_page_reference.replaceWith(text_from_page)
|
||||
|
||||
def build_confluence_client(
|
||||
credentials: dict[str, Any],
|
||||
is_cloud: bool,
|
||||
wiki_base: str,
|
||||
) -> OnyxConfluence:
|
||||
try:
|
||||
_validate_connector_configuration(
|
||||
credentials=credentials,
|
||||
is_cloud=is_cloud,
|
||||
wiki_base=wiki_base,
|
||||
)
|
||||
except Exception as e:
|
||||
raise ConnectorValidationError(str(e))
|
||||
for html_link_body in soup.findAll("ac:link-body"):
|
||||
# This extracts the text from inline links in the page so they can be
|
||||
# represented in the document text as plain text
|
||||
try:
|
||||
text_from_link = html_link_body.text
|
||||
html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error processing ac:link-body: {e}")
|
||||
|
||||
return OnyxConfluence(
|
||||
api_version="cloud" if is_cloud else "latest",
|
||||
# Remove trailing slash from wiki_base if present
|
||||
url=wiki_base.rstrip("/"),
|
||||
# passing in username causes issues for Confluence data center
|
||||
username=credentials["confluence_username"] if is_cloud else None,
|
||||
password=credentials["confluence_access_token"] if is_cloud else None,
|
||||
token=credentials["confluence_access_token"] if not is_cloud else None,
|
||||
backoff_and_retry=True,
|
||||
max_backoff_retries=10,
|
||||
max_backoff_seconds=60,
|
||||
cloud=is_cloud,
|
||||
)
|
||||
return format_document_soup(soup)
|
||||
|
||||
@@ -1,185 +1,38 @@
|
||||
import io
|
||||
import math
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
from urllib.parse import parse_qs
|
||||
from urllib.parse import quote
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import bs4
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
|
||||
pass
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_USER_EMAIL_CACHE: dict[str, str | None] = {}
|
||||
CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token"
|
||||
RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower()
|
||||
|
||||
|
||||
def get_user_email_from_username__server(
    confluence_client: "OnyxConfluence", user_name: str
) -> str | None:
    """Return the email for a Confluence user name, using a module-level cache.

    Args:
        confluence_client: Client used to call ``get_mobile_parameters``.
        user_name: The user name to look up.

    Returns:
        The ``email`` field of the lookup response, or None when the lookup
        raised or the response carried no email.
    """
    cached_email = _USER_EMAIL_CACHE.get(user_name)
    if cached_email is not None:
        return cached_email

    try:
        email = confluence_client.get_mobile_parameters(user_name).get("email")
    except Exception:
        logger.warning(f"failed to get confluence email for {user_name}")
        # A failure is stored as None, and None entries are looked up again on
        # the next call, so failed users are retried rather than permanently
        # marked as failed.
        email = None

    _USER_EMAIL_CACHE[user_name] = email
    return email
|
||||
|
||||
|
||||
_USER_NOT_FOUND = "Unknown Confluence User"
|
||||
_USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {}
|
||||
|
||||
|
||||
def _get_user(confluence_client: "OnyxConfluence", user_id: str) -> str:
    """Resolve a Confluence user id to a display name, with caching.

    Args:
        confluence_client (Confluence): The Confluence Client
        user_id (str): The user id (i.e: the account-id or userkey)

    Returns:
        str: The user's display name, or the ``_USER_NOT_FOUND`` placeholder
        when neither lookup yields a name.
    """
    if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None:
        display_name = None
        # Try the userkey lookup first, then fall back to the account-id lookup.
        # The attribute is resolved inside the try so that any failure of a
        # lookup (including a missing method) is treated as "not found".
        for lookup_name in (
            "get_user_details_by_userkey",
            "get_user_details_by_accountid",
        ):
            try:
                details = getattr(confluence_client, lookup_name)(user_id)
                display_name = details.get("displayName")
            except Exception:
                display_name = None

            if display_name:
                break

        _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = display_name

    return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
    confluence_client: "OnyxConfluence",
    confluence_object: dict[str, Any],
    fetched_titles: set[str],
) -> str:
    """Parse a Confluence html page into plain text.

    User references are replaced by "@<display name>", "include" macros are
    replaced by the (recursively extracted) text of the included page, and
    inline link bodies are replaced by "(LINK TEXT: ...)".

    Args:
        confluence_client (Confluence): Confluence client
        confluence_object (dict): The confluence object as a dict
        fetched_titles (set[str]): The titles of the pages that have already
            been fetched. Mutated in place: every included page title that is
            processed is added, which prevents recursive re-fetching.

    Returns:
        str: loaded and formatted Confluence page
    """
    body = confluence_object["body"]
    object_html = body.get("storage", body.get("view", {})).get("value")

    soup = bs4.BeautifulSoup(object_html, "html.parser")

    # Replace user mentions with the user's display name.
    for user in soup.find_all("ri:user"):
        user_id = (
            user.attrs["ri:account-id"]
            if "ri:account-id" in user.attrs
            else user.get("ri:userkey")
        )
        if not user_id:
            logger.warning(
                "ri:userkey not found in ri:user element. " f"Found attrs: {user.attrs}"
            )
            continue
        # Include @ sign for tagging, more clear for LLM
        user.replace_with("@" + _get_user(confluence_client, user_id))

    # Inline the text of pages pulled in through "include" macros.
    for html_page_reference in soup.find_all("ac:structured-macro"):
        # Here, we only want to process page within page macros
        if html_page_reference.attrs.get("ac:name") != "include":
            continue

        page_data = html_page_reference.find("ri:page")
        if not page_data:
            logger.warning(
                f"Skipping retrieval of {html_page_reference} because page data is missing"
            )
            continue

        page_title = page_data.attrs.get("ri:content-title")
        if not page_title:
            # only fetch pages that have a title
            logger.warning(
                f"Skipping retrieval of {html_page_reference} because it has no title"
            )
            continue

        if page_title in fetched_titles:
            # prevent recursive fetching of pages
            logger.debug(f"Skipping {page_title} because it has already been fetched")
            continue

        fetched_titles.add(page_title)

        # Wrap this in a try-except because there are some pages that might not exist
        try:
            page_query = f"type=page and title='{quote(page_title)}'"

            # Confluence enforces title uniqueness, so we should only get one result here
            page_contents: dict[str, Any] | None = next(
                iter(
                    confluence_client.paginated_cql_retrieval(
                        cql=page_query,
                        expand="body.storage.value",
                        limit=1,
                    )
                ),
                None,
            )
        except Exception as e:
            logger.warning(
                f"Error getting page contents for object {confluence_object}: {e}"
            )
            continue

        if not page_contents:
            continue

        text_from_page = extract_text_from_confluence_html(
            confluence_client=confluence_client,
            confluence_object=page_contents,
            fetched_titles=fetched_titles,
        )

        html_page_reference.replace_with(text_from_page)

    for html_link_body in soup.find_all("ac:link-body"):
        # This extracts the text from inline links in the page so they can be
        # represented in the document text as plain text
        try:
            text_from_link = html_link_body.text
            html_link_body.replace_with(f"(LINK TEXT: {text_from_link})")
        except Exception as e:
            logger.warning(f"Error processing ac:link-body: {e}")

    return format_document_soup(soup)
|
||||
class TokenResponse(BaseModel):
    """Fields expected in the JSON body returned by the OAuth token endpoint.

    Used to validate the response of a token refresh request.
    """

    # the new access token
    access_token: str
    # lifetime of the access token; treated as a number of seconds when the
    # expiry time is computed
    expires_in: int
    token_type: str
    # the rotated refresh token to use for the next refresh
    refresh_token: str
    scope: str
|
||||
|
||||
|
||||
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
@@ -193,49 +46,6 @@ def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
|
||||
]
|
||||
|
||||
|
||||
def attachment_to_content(
    confluence_client: "OnyxConfluence",
    attachment: dict[str, Any],
) -> str | None:
    """Download a Confluence attachment and return its extracted text.

    Returns None (meaning "skip this attachment") when the file type is not
    accepted, the attachment is larger than the size threshold, the download
    does not return HTTP 200, or the extracted text exceeds the character
    count threshold.
    """
    if not validate_attachment_filetype(attachment):
        return None

    download_link = confluence_client.url + attachment["_links"]["download"]

    size_limit = CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
    attachment_size = attachment["extensions"]["fileSize"]
    if attachment_size > size_limit:
        logger.warning(
            f"Skipping {download_link} due to size. "
            f"size={attachment_size} "
            f"threshold={size_limit}"
        )
        return None

    logger.info(f"_attachment_to_content - _session.get: link={download_link}")
    response = confluence_client._session.get(download_link)
    status = response.status_code
    if status != 200:
        logger.warning(
            f"Failed to fetch {download_link} with invalid status code {status}"
        )
        return None

    extracted_text = extract_file_text(
        io.BytesIO(response.content),
        file_name=attachment["title"],
        break_on_unprocessable=False,
    )

    char_limit = CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD
    char_count = len(extracted_text)
    if char_count > char_limit:
        logger.warning(
            f"Skipping {download_link} due to char count. "
            f"char count={char_count} "
            f"threshold={char_limit}"
        )
        return None

    return extracted_text
|
||||
|
||||
|
||||
def build_confluence_document_id(
|
||||
base_url: str, content_url: str, is_cloud: bool
|
||||
) -> str:
|
||||
@@ -284,6 +94,137 @@ def datetime_from_string(datetime_string: str) -> datetime:
|
||||
return datetime_object
|
||||
|
||||
|
||||
def confluence_refresh_tokens(
    client_id: str,
    client_secret: str,
    cloud_id: str,
    refresh_token: str,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Rotate the Confluence Cloud OAuth refresh and access tokens.

    Note that access tokens are only good for an hour in confluence cloud,
    so we're going to have problems if the connector runs for longer
    https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair

    Args:
        client_id: OAuth client id.
        client_secret: OAuth client secret.
        cloud_id: Copied unchanged into the returned credentials.
        refresh_token: The current refresh token.
        timeout: Seconds to wait for the token endpoint before giving up, so a
            stalled connection cannot block the caller indefinitely.

    Returns:
        A new credentials dict holding the rotated tokens, creation and expiry
        timestamps (ISO format, UTC), the token lifetime, scope and cloud id.

    Raises:
        RuntimeError: If the response body is not a valid token response.
    """
    response = requests.post(
        CONFLUENCE_OAUTH_TOKEN_URL,
        headers={"Content-Type": "application/x-www-form-urlencoded"},
        data={
            "grant_type": "refresh_token",
            "client_id": client_id,
            "client_secret": client_secret,
            "refresh_token": refresh_token,
        },
        timeout=timeout,
    )

    try:
        token_response = TokenResponse.model_validate_json(response.text)
    except Exception as e:
        # keep the validation failure attached as the explicit cause
        raise RuntimeError("Confluence Cloud token refresh failed.") from e

    now = datetime.now(timezone.utc)
    expires_at = now + timedelta(seconds=token_response.expires_in)

    new_credentials: dict[str, Any] = {
        "confluence_access_token": token_response.access_token,
        "confluence_refresh_token": token_response.refresh_token,
        "created_at": now.isoformat(),
        "expires_at": expires_at.isoformat(),
        "expires_in": token_response.expires_in,
        "scope": token_response.scope,
        "cloud_id": cloud_id,
    }
    return new_credentials
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable[..., Any])


# https://developer.atlassian.com/cloud/confluence/rate-limiting/
# this uses the native rate limiting option provided by the
# confluence client and otherwise applies a simpler set of error handling
def handle_confluence_rate_limit(confluence_call: F) -> F:
    """Wrap a Confluence client call with retry handling.

    The wrapped call is attempted up to MAX_RETRIES times within an overall
    TIMEOUT budget. Rate-limit HTTP errors wait for the delay computed by
    `_handle_http_error` before retrying; AttributeErrors raised from inside
    the client are retried after a short fixed sleep.

    Args:
        confluence_call: The callable to wrap.

    Returns:
        A wrapper with the same call signature as `confluence_call`.

    Raises (from the wrapper):
        TimeoutError: If the attempts took longer than TIMEOUT seconds.
        requests.HTTPError: For non rate-limit HTTP errors, or when the final
            attempt is still failing.
        AttributeError: If the final attempt still raises it.
    """

    def wrapped_call(*args: list[Any], **kwargs: Any) -> Any:
        MAX_RETRIES = 5

        TIMEOUT = 600
        timeout_at = time.monotonic() + TIMEOUT

        for attempt in range(MAX_RETRIES):
            if time.monotonic() > timeout_at:
                raise TimeoutError(
                    f"Confluence call attempts took longer than {TIMEOUT} seconds."
                )

            is_last_attempt = attempt == MAX_RETRIES - 1

            try:
                # we're relying more on the client to rate limit itself
                # and applying our own retries in a more specific set of circumstances
                return confluence_call(*args, **kwargs)
            except requests.HTTPError as e:
                if is_last_attempt:
                    # Out of retries: surface the error instead of sleeping and
                    # then silently falling out of the loop with a None result.
                    raise

                # raises immediately if this is not a rate limiting error
                delay_until = _handle_http_error(e, attempt)

                # delay_until is a time.monotonic() deadline, so report the
                # remaining wait rather than the raw clock value
                remaining = max(0.0, delay_until - time.monotonic())
                logger.warning(
                    f"HTTPError in confluence call. "
                    f"Retrying in {remaining:.0f} seconds..."
                )
                while time.monotonic() < delay_until:
                    # in the future, check a signal here to exit
                    time.sleep(1)
            except AttributeError:
                # Some error within the Confluence library, unclear why it fails.
                # Users reported it to be intermittent, so just retry
                if is_last_attempt:
                    raise

                logger.exception(
                    "Confluence Client raised an AttributeError. Retrying..."
                )
                time.sleep(5)

    return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
    """Decide how long to wait after an HTTPError, or re-raise it.

    The error is re-raised when it carries no response/headers, or when it is
    not a rate limiting error (status is not 429 and the body does not contain
    the rate limit message).

    Args:
        e: The HTTP error raised by the call.
        attempt: Zero-based attempt number, used for exponential backoff.

    Returns:
        The `time.monotonic()` value (rounded up) at which to retry.
    """
    MIN_DELAY = 2
    MAX_DELAY = 60
    STARTING_DELAY = 5
    BACKOFF = 2

    response = e.response

    # Check if the response or headers are None to avoid potential AttributeError
    if response is None or response.headers is None:
        logger.warning("HTTPError with `None` as response or as headers")
        raise e

    is_rate_limited = (
        response.status_code == 429
        or RATE_LIMIT_MESSAGE_LOWERCASE in response.text.lower()
    )
    if not is_rate_limited:
        raise e

    retry_after = None
    header_value = response.headers.get("Retry-After")
    if header_value is not None:
        try:
            retry_after = int(header_value)
        except ValueError:
            # non-integer header value: fall through to exponential backoff
            retry_after = None
        else:
            if retry_after > MAX_DELAY:
                logger.warning(
                    f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..."
                )
                retry_after = MAX_DELAY
            elif retry_after < MIN_DELAY:
                retry_after = MIN_DELAY

    if retry_after is None:
        logger.warning(
            "Rate limiting without retry header. Retrying with exponential backoff..."
        )
        delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
    else:
        logger.warning(
            f"Rate limiting with retry header. Retrying after {retry_after} seconds..."
        )
        delay = retry_after

    return math.ceil(time.monotonic() + delay)
|
||||
|
||||
|
||||
def get_single_param_from_url(url: str, param: str) -> str | None:
|
||||
"""Get a parameter from a url"""
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
135
backend/onyx/connectors/credentials_provider.py
Normal file
135
backend/onyx/connectors/credentials_provider.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import uuid
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy import select
|
||||
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import Credential
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
|
||||
class OnyxDBCredentialsProvider(
    CredentialsProviderInterface["OnyxDBCredentialsProvider"]
):
    """Implementation to allow the connector to callback and update credentials in the db.
    Required in cases where credentials can rotate while the connector is running.

    Used as a context manager: entering acquires a redis lock keyed on the
    connector name and credential id, exiting releases it.
    """

    LOCK_TTL = 900  # TTL of the lock

    def __init__(self, tenant_id: str, connector_name: str, credential_id: int):
        self._tenant_id = tenant_id
        self._connector_name = connector_name
        self._credential_id = credential_id

        self.redis_client = get_redis_client(tenant_id=tenant_id)

        # lock used to prevent overlapping renewal of credentials
        self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}"
        self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL)

    def __enter__(self) -> "OnyxDBCredentialsProvider":
        """Acquire the lock, waiting up to LOCK_TTL.

        Raises:
            RuntimeError: If the lock could not be acquired in time.
        """
        acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL)
        if not acquired:
            raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}")

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Release the lock when exiting the context."""
        # only release if we still own it (e.g. it has not already expired)
        if self._lock and self._lock.owned():
            self._lock.release()

    def get_tenant_id(self) -> str | None:
        """Return the tenant id this provider was created for."""
        return self._tenant_id

    def get_provider_key(self) -> str:
        """Return the credential id as a string."""
        return str(self._credential_id)

    def get_credentials(self) -> dict[str, Any]:
        """Read the credential json from the db.

        Raises:
            ValueError: If no credential row exists for the credential id.
        """
        with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
            # scalar_one_or_none so that a missing row reaches the explicit
            # ValueError below (scalar_one would raise before the check)
            credential = db_session.execute(
                select(Credential).where(Credential.id == self._credential_id)
            ).scalar_one_or_none()

            if credential is None:
                raise ValueError(
                    f"No credential found: credential={self._credential_id}"
                )

            return credential.credential_json

    def set_credentials(self, credential_json: dict[str, Any]) -> None:
        """Overwrite the credential json in the db.

        The row is selected FOR UPDATE and the change committed; any failure
        rolls the session back and re-raises.

        Raises:
            ValueError: If no credential row exists for the credential id.
        """
        with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
            try:
                credential = db_session.execute(
                    select(Credential)
                    .where(Credential.id == self._credential_id)
                    .with_for_update()
                ).scalar_one_or_none()

                if credential is None:
                    raise ValueError(
                        f"No credential found: credential={self._credential_id}"
                    )

                credential.credential_json = credential_json
                db_session.commit()
            except Exception:
                db_session.rollback()
                raise

    def is_dynamic(self) -> bool:
        """Always True: these credentials may change while in use."""
        return True
|
||||
|
||||
|
||||
class OnyxStaticCredentialsProvider(
    CredentialsProviderInterface["OnyxStaticCredentialsProvider"]
):
    """Implementation (a very simple one!) to handle static credentials."""

    def __init__(
        self,
        tenant_id: str | None,
        connector_name: str,
        credential_json: dict[str, Any],
    ):
        self._tenant_id = tenant_id
        self._connector_name = connector_name
        self._credential_json = credential_json

        # random key generated once per provider instance
        self._provider_key = str(uuid.uuid4())

    def __enter__(self) -> "OnyxStaticCredentialsProvider":
        """Return self; no lock is taken."""
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        """Nothing to release."""
        pass

    def get_tenant_id(self) -> str | None:
        """Return the tenant id given at construction."""
        return self._tenant_id

    def get_provider_key(self) -> str:
        """Return the random key generated at construction."""
        return self._provider_key

    def get_credentials(self) -> dict[str, Any]:
        """Return the credential json currently held in memory."""
        return self._credential_json

    def set_credentials(self, credential_json: dict[str, Any]) -> None:
        """Replace the credential json held in memory."""
        self._credential_json = credential_json

    def is_dynamic(self) -> bool:
        """Always False."""
        return False
|
||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
@@ -13,6 +12,7 @@ from onyx.connectors.blob.connector import BlobStorageConnector
|
||||
from onyx.connectors.bookstack.connector import BookstackConnector
|
||||
from onyx.connectors.clickup.connector import ClickupConnector
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
|
||||
from onyx.connectors.discord.connector import DiscordConnector
|
||||
from onyx.connectors.discourse.connector import DiscourseConnector
|
||||
from onyx.connectors.document360.connector import Document360Connector
|
||||
@@ -33,6 +33,7 @@ from onyx.connectors.guru.connector import GuruConnector
|
||||
from onyx.connectors.hubspot.connector import HubSpotConnector
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointConnector
|
||||
from onyx.connectors.interfaces import CredentialsConnector
|
||||
from onyx.connectors.interfaces import EventConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
@@ -58,6 +59,7 @@ from onyx.db.connector import fetch_connector_by_id
|
||||
from onyx.db.credentials import backend_update_credential_json
|
||||
from onyx.db.credentials import fetch_credential_by_id
|
||||
from onyx.db.models import Credential
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
class ConnectorMissingException(Exception):
|
||||
@@ -164,18 +166,21 @@ def instantiate_connector(
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credential: Credential,
|
||||
tenant_id: str | None = None,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
if source in DocumentSourceRequiringTenantContext:
|
||||
connector_specific_config["tenant_id"] = tenant_id
|
||||
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
if isinstance(connector, CredentialsConnector):
|
||||
provider = OnyxDBCredentialsProvider(
|
||||
get_current_tenant_id(), str(source), credential.id
|
||||
)
|
||||
connector.set_credentials_provider(provider)
|
||||
else:
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
return connector
|
||||
|
||||
@@ -184,7 +189,6 @@ def validate_ccpair_for_user(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None,
|
||||
enforce_creation: bool = True,
|
||||
) -> bool:
|
||||
if INTEGRATION_TESTS_MODE:
|
||||
@@ -216,7 +220,6 @@ def validate_ccpair_for_user(
|
||||
input_type=connector.input_type,
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
credential=credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ConnectorValidationError as e:
|
||||
raise e
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -27,8 +27,6 @@ from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -165,12 +163,10 @@ class LocalFileConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [Path(file_location) for file_location in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.tenant_id = tenant_id
|
||||
self.pdf_pass: str | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -179,9 +175,8 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for file_path in self.file_locations:
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
files = _read_files_and_metadata(
|
||||
@@ -203,8 +198,6 @@ class LocalFileConnector(LoadConnector):
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import openpyxl # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
@@ -43,12 +45,15 @@ def _extract_sections_basic(
|
||||
) -> list[Section]:
|
||||
mime_type = file["mimeType"]
|
||||
link = file["webViewLink"]
|
||||
supported_file_types = set(item.value for item in GDriveMimeType)
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
if mime_type not in supported_file_types:
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
try:
|
||||
# ---------------------------
|
||||
# Google Sheets extraction
|
||||
if mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
try:
|
||||
sheets_service = build(
|
||||
@@ -109,7 +114,53 @@ def _extract_sections_basic(
|
||||
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
|
||||
" Falling back to basic extraction."
|
||||
)
|
||||
# ---------------------------
|
||||
# Microsoft Excel (.xlsx or .xls) extraction branch
|
||||
elif mime_type in [
|
||||
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
|
||||
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
|
||||
]:
|
||||
try:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
|
||||
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
|
||||
tmp.write(response)
|
||||
tmp_path = tmp.name
|
||||
|
||||
section_separator = "\n\n"
|
||||
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
|
||||
|
||||
# Work similarly to the xlsx_to_text function used for file connector
|
||||
# but returns Sections instead of a string
|
||||
sections = [
|
||||
Section(
|
||||
link=link,
|
||||
text=(
|
||||
f"Sheet: {sheet.title}\n\n"
|
||||
+ section_separator.join(
|
||||
",".join(map(str, row))
|
||||
for row in sheet.iter_rows(
|
||||
min_row=1, values_only=True
|
||||
)
|
||||
if row
|
||||
)
|
||||
),
|
||||
)
|
||||
for sheet in workbook.worksheets
|
||||
]
|
||||
|
||||
return sections
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error extracting data from Excel file '{file['name']}': {e}"
|
||||
)
|
||||
return [
|
||||
Section(link=link, text="Error extracting data from Excel file")
|
||||
]
|
||||
|
||||
# ---------------------------
|
||||
# Export for Google Docs, PPT, and fallback for spreadsheets
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
@@ -128,6 +179,8 @@ def _extract_sections_basic(
|
||||
)
|
||||
return [Section(link=link, text=text)]
|
||||
|
||||
# ---------------------------
|
||||
# Plain text and Markdown files
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
@@ -141,6 +194,8 @@ def _extract_sections_basic(
|
||||
.decode("utf-8"),
|
||||
)
|
||||
]
|
||||
# ---------------------------
|
||||
# Word, PowerPoint, PDF files
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
@@ -170,7 +225,11 @@ def _extract_sections_basic(
|
||||
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
# Catch-all case, should not happen since there should be specific handling
|
||||
# for each of the supported file types
|
||||
error_message = f"Unsupported file type: {mime_type}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except Exception:
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
@@ -5,6 +5,10 @@ from typing import Any
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
SPREADSHEET_OPEN_FORMAT = (
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
)
|
||||
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from types import TracebackType
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -111,6 +114,69 @@ class OAuthConnector(BaseConnector):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
T = TypeVar("T", bound="CredentialsProviderInterface")
|
||||
|
||||
|
||||
class CredentialsProviderInterface(abc.ABC, Generic[T]):
    """Abstract context-manager interface for reading and writing credentials."""

    @abc.abstractmethod
    def __enter__(self) -> T:
        raise NotImplementedError

    @abc.abstractmethod
    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        traceback: TracebackType | None,
    ) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def get_tenant_id(self) -> str | None:
        raise NotImplementedError

    @abc.abstractmethod
    def get_provider_key(self) -> str:
        """a unique key that the connector can use to lock around a credential
        that might be used simultaneously.

        Will typically be the credential id, but can also just be something random
        in cases when there is nothing to lock (aka static credentials)
        """
        raise NotImplementedError

    @abc.abstractmethod
    def get_credentials(self) -> dict[str, Any]:
        raise NotImplementedError

    @abc.abstractmethod
    def set_credentials(self, credential_json: dict[str, Any]) -> None:
        raise NotImplementedError

    @abc.abstractmethod
    def is_dynamic(self) -> bool:
        """If dynamic, the credentials may change during usage ... meaning the client
        needs to use the locking features of the credentials provider to operate
        correctly.

        If static, the client can simply reference the credentials once and use them
        through the entire indexing run.
        """
        raise NotImplementedError
|
||||
|
||||
|
||||
class CredentialsConnector(BaseConnector):
    """Implement this if the connector needs to be able to read and write credentials
    on the fly. Typically used with shared credentials/tokens that might be renewed
    at any time."""

    @abc.abstractmethod
    def set_credentials_provider(
        self, credentials_provider: CredentialsProviderInterface
    ) -> None:
        """Give the connector the provider it should use to get and set credentials."""
        raise NotImplementedError
|
||||
|
||||
|
||||
# Event driven
|
||||
class EventConnector(BaseConnector):
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -302,29 +302,29 @@ class WebConnector(LoadConnector):
|
||||
playwright, context = start_playwright()
|
||||
restart_playwright = False
|
||||
while to_visit:
|
||||
current_url = to_visit.pop()
|
||||
if current_url in visited_links:
|
||||
initial_url = to_visit.pop()
|
||||
if initial_url in visited_links:
|
||||
continue
|
||||
visited_links.add(current_url)
|
||||
visited_links.add(initial_url)
|
||||
|
||||
try:
|
||||
protected_url_check(current_url)
|
||||
protected_url_check(initial_url)
|
||||
except Exception as e:
|
||||
last_error = f"Invalid URL {current_url} due to {e}"
|
||||
last_error = f"Invalid URL {initial_url} due to {e}"
|
||||
logger.warning(last_error)
|
||||
continue
|
||||
|
||||
logger.info(f"Visiting {current_url}")
|
||||
logger.info(f"{len(visited_links)}: Visiting {initial_url}")
|
||||
|
||||
try:
|
||||
check_internet_connection(current_url)
|
||||
check_internet_connection(initial_url)
|
||||
if restart_playwright:
|
||||
playwright, context = start_playwright()
|
||||
restart_playwright = False
|
||||
|
||||
if current_url.split(".")[-1] == "pdf":
|
||||
if initial_url.split(".")[-1] == "pdf":
|
||||
# PDF files are not checked for links
|
||||
response = requests.get(current_url)
|
||||
response = requests.get(initial_url)
|
||||
page_text, metadata = read_pdf_file(
|
||||
file=io.BytesIO(response.content)
|
||||
)
|
||||
@@ -332,10 +332,10 @@ class WebConnector(LoadConnector):
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=current_url,
|
||||
sections=[Section(link=current_url, text=page_text)],
|
||||
id=initial_url,
|
||||
sections=[Section(link=initial_url, text=page_text)],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=current_url.split("/")[-1],
|
||||
semantic_identifier=initial_url.split("/")[-1],
|
||||
metadata=metadata,
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
@@ -347,21 +347,25 @@ class WebConnector(LoadConnector):
|
||||
continue
|
||||
|
||||
page = context.new_page()
|
||||
page_response = page.goto(current_url)
|
||||
page_response = page.goto(initial_url)
|
||||
last_modified = (
|
||||
page_response.header_value("Last-Modified")
|
||||
if page_response
|
||||
else None
|
||||
)
|
||||
final_page = page.url
|
||||
if final_page != current_url:
|
||||
logger.info(f"Redirected to {final_page}")
|
||||
protected_url_check(final_page)
|
||||
current_url = final_page
|
||||
if current_url in visited_links:
|
||||
logger.info("Redirected page already indexed")
|
||||
final_url = page.url
|
||||
if final_url != initial_url:
|
||||
protected_url_check(final_url)
|
||||
initial_url = final_url
|
||||
if initial_url in visited_links:
|
||||
logger.info(
|
||||
f"{len(visited_links)}: {initial_url} redirected to {final_url} - already indexed"
|
||||
)
|
||||
continue
|
||||
visited_links.add(current_url)
|
||||
logger.info(
|
||||
f"{len(visited_links)}: {initial_url} redirected to {final_url}"
|
||||
)
|
||||
visited_links.add(initial_url)
|
||||
|
||||
if self.scroll_before_scraping:
|
||||
scroll_attempts = 0
|
||||
@@ -379,13 +383,13 @@ class WebConnector(LoadConnector):
|
||||
soup = BeautifulSoup(content, "html.parser")
|
||||
|
||||
if self.recursive:
|
||||
internal_links = get_internal_links(base_url, current_url, soup)
|
||||
internal_links = get_internal_links(base_url, initial_url, soup)
|
||||
for link in internal_links:
|
||||
if link not in visited_links:
|
||||
to_visit.append(link)
|
||||
|
||||
if page_response and str(page_response.status)[0] in ("4", "5"):
|
||||
last_error = f"Skipped indexing {current_url} due to HTTP {page_response.status} response"
|
||||
last_error = f"Skipped indexing {initial_url} due to HTTP {page_response.status} response"
|
||||
logger.info(last_error)
|
||||
continue
|
||||
|
||||
@@ -393,12 +397,12 @@ class WebConnector(LoadConnector):
|
||||
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=current_url,
|
||||
id=initial_url,
|
||||
sections=[
|
||||
Section(link=current_url, text=parsed_html.cleaned_text)
|
||||
Section(link=initial_url, text=parsed_html.cleaned_text)
|
||||
],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=parsed_html.title or current_url,
|
||||
semantic_identifier=parsed_html.title or initial_url,
|
||||
metadata={},
|
||||
doc_updated_at=_get_datetime_from_last_modified_header(
|
||||
last_modified
|
||||
@@ -410,7 +414,7 @@ class WebConnector(LoadConnector):
|
||||
|
||||
page.close()
|
||||
except Exception as e:
|
||||
last_error = f"Failed to fetch '{current_url}': {e}"
|
||||
last_error = f"Failed to fetch '{initial_url}': {e}"
|
||||
logger.exception(last_error)
|
||||
playwright.stop()
|
||||
restart_playwright = True
|
||||
|
||||
@@ -16,7 +16,6 @@ from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
@@ -73,7 +72,7 @@ def insert_api_key(
|
||||
# Get tenant_id from context var (will be default schema for single tenant)
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
|
||||
api_key = generate_api_key(tenant_id)
|
||||
api_key_user_id = uuid.uuid4()
|
||||
|
||||
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
|
||||
@@ -168,7 +168,7 @@ def get_chat_sessions_by_user(
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
stmt = stmt.order_by(desc(ChatSession.time_updated))
|
||||
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
@@ -962,6 +962,7 @@ def translate_db_message_to_chat_message_detail(
|
||||
chat_message.sub_questions
|
||||
),
|
||||
refined_answer_improvement=chat_message.refined_answer_improvement,
|
||||
is_agentic=chat_message.is_agentic,
|
||||
error=chat_message.error,
|
||||
)
|
||||
|
||||
|
||||
@@ -3,14 +3,13 @@ from typing import Optional
|
||||
from typing import Tuple
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import column
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import literal
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import union_all
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause
|
||||
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -26,127 +25,87 @@ def search_chat_sessions(
|
||||
include_onyxbot_flows: bool = False,
|
||||
) -> Tuple[List[ChatSession], bool]:
|
||||
"""
|
||||
Search for chat sessions based on the provided query.
|
||||
If no query is provided, returns recent chat sessions.
|
||||
Fast full-text search on ChatSession + ChatMessage using tsvectors.
|
||||
|
||||
Returns a tuple of (chat_sessions, has_more)
|
||||
If no query is provided, returns the most recent chat sessions.
|
||||
Otherwise, searches both chat messages and session descriptions.
|
||||
|
||||
Returns a tuple of (sessions, has_more) where has_more indicates if
|
||||
there are additional results beyond the requested page.
|
||||
"""
|
||||
offset = (page - 1) * page_size
|
||||
offset_val = (page - 1) * page_size
|
||||
|
||||
# If no search query, we use standard SQLAlchemy pagination
|
||||
# If no query, just return the most recent sessions
|
||||
if not query or not query.strip():
|
||||
stmt = select(ChatSession)
|
||||
if user_id:
|
||||
stmt = (
|
||||
select(ChatSession)
|
||||
.order_by(desc(ChatSession.time_created))
|
||||
.offset(offset_val)
|
||||
.limit(page_size + 1)
|
||||
)
|
||||
if user_id is not None:
|
||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.offset(offset).limit(page_size + 1)
|
||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
||||
chat_sessions = result.scalars().all()
|
||||
sessions = result.scalars().all()
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
has_more = len(sessions) > page_size
|
||||
if has_more:
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
sessions = sessions[:page_size]
|
||||
|
||||
return list(chat_sessions), has_more
|
||||
return list(sessions), has_more
|
||||
|
||||
words = query.lower().strip().split()
|
||||
# Otherwise, proceed with full-text search
|
||||
query = query.strip()
|
||||
|
||||
# Message match subquery
|
||||
message_matches = []
|
||||
for word in words:
|
||||
word_like = f"%{word}%"
|
||||
message_match: Select = (
|
||||
select(ChatMessage.chat_session_id, literal(1.0).label("search_rank"))
|
||||
.join(ChatSession, ChatSession.id == ChatMessage.chat_session_id)
|
||||
.where(func.lower(ChatMessage.message).like(word_like))
|
||||
)
|
||||
|
||||
if user_id:
|
||||
message_match = message_match.where(ChatSession.user_id == user_id)
|
||||
|
||||
message_matches.append(message_match)
|
||||
|
||||
if message_matches:
|
||||
message_matches_query = union_all(*message_matches).alias("message_matches")
|
||||
else:
|
||||
return [], False
|
||||
|
||||
# Description matches
|
||||
description_match: Select = select(
|
||||
ChatSession.id.label("chat_session_id"), literal(0.5).label("search_rank")
|
||||
).where(func.lower(ChatSession.description).like(f"%{query.lower()}%"))
|
||||
|
||||
if user_id:
|
||||
description_match = description_match.where(ChatSession.user_id == user_id)
|
||||
base_conditions = []
|
||||
if user_id is not None:
|
||||
base_conditions.append(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
description_match = description_match.where(ChatSession.onyxbot_flow.is_(False))
|
||||
base_conditions.append(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
description_match = description_match.where(ChatSession.deleted.is_(False))
|
||||
base_conditions.append(ChatSession.deleted.is_(False))
|
||||
|
||||
# Combine all match sources
|
||||
combined_matches = union_all(
|
||||
message_matches_query.select(), description_match
|
||||
).alias("combined_matches")
|
||||
message_tsv: ColumnClause = column("message_tsv")
|
||||
description_tsv: ColumnClause = column("description_tsv")
|
||||
|
||||
# Use CTE to group and get max rank
|
||||
session_ranks = (
|
||||
select(
|
||||
combined_matches.c.chat_session_id,
|
||||
func.max(combined_matches.c.search_rank).label("rank"),
|
||||
)
|
||||
.group_by(combined_matches.c.chat_session_id)
|
||||
.alias("session_ranks")
|
||||
ts_query = func.plainto_tsquery("english", query)
|
||||
|
||||
description_session_ids = (
|
||||
select(ChatSession.id)
|
||||
.where(*base_conditions)
|
||||
.where(description_tsv.op("@@")(ts_query))
|
||||
)
|
||||
|
||||
# Get ranked sessions with pagination
|
||||
ranked_query = (
|
||||
db_session.query(session_ranks.c.chat_session_id, session_ranks.c.rank)
|
||||
.order_by(desc(session_ranks.c.rank), session_ranks.c.chat_session_id)
|
||||
.offset(offset)
|
||||
message_session_ids = (
|
||||
select(ChatMessage.chat_session_id)
|
||||
.join(ChatSession, ChatMessage.chat_session_id == ChatSession.id)
|
||||
.where(*base_conditions)
|
||||
.where(message_tsv.op("@@")(ts_query))
|
||||
)
|
||||
|
||||
combined_ids = description_session_ids.union(message_session_ids).alias(
|
||||
"combined_ids"
|
||||
)
|
||||
|
||||
final_stmt = (
|
||||
select(ChatSession)
|
||||
.join(combined_ids, ChatSession.id == combined_ids.c.id)
|
||||
.order_by(desc(ChatSession.time_created))
|
||||
.distinct()
|
||||
.offset(offset_val)
|
||||
.limit(page_size + 1)
|
||||
.options(joinedload(ChatSession.persona))
|
||||
)
|
||||
|
||||
result = ranked_query.all()
|
||||
session_objs = db_session.execute(final_stmt).scalars().all()
|
||||
|
||||
# Extract session IDs and ranks
|
||||
session_ids_with_ranks = {row.chat_session_id: row.rank for row in result}
|
||||
session_ids = list(session_ids_with_ranks.keys())
|
||||
|
||||
if not session_ids:
|
||||
return [], False
|
||||
|
||||
# Now, let's query the actual ChatSession objects using the IDs
|
||||
stmt = select(ChatSession).where(ChatSession.id.in_(session_ids))
|
||||
|
||||
if user_id:
|
||||
stmt = stmt.where(ChatSession.user_id == user_id)
|
||||
if not include_onyxbot_flows:
|
||||
stmt = stmt.where(ChatSession.onyxbot_flow.is_(False))
|
||||
if not include_deleted:
|
||||
stmt = stmt.where(ChatSession.deleted.is_(False))
|
||||
|
||||
# Full objects with eager loading
|
||||
result = db_session.execute(stmt.options(joinedload(ChatSession.persona)))
|
||||
chat_sessions = result.scalars().all()
|
||||
|
||||
# Sort based on above ranking
|
||||
chat_sessions = sorted(
|
||||
chat_sessions,
|
||||
key=lambda session: (
|
||||
-session_ids_with_ranks.get(session.id, 0), # Rank (higher first)
|
||||
session.time_created.timestamp() * -1, # Then by time (newest first)
|
||||
),
|
||||
)
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
has_more = len(session_objs) > page_size
|
||||
if has_more:
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
session_objs = session_objs[:page_size]
|
||||
|
||||
return chat_sessions, has_more
|
||||
return list(session_objs), has_more
|
||||
|
||||
@@ -360,18 +360,13 @@ def backend_update_credential_json(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_credential(
|
||||
def _delete_credential_internal(
|
||||
credential: Credential,
|
||||
credential_id: int,
|
||||
user: User | None,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"Credential by provided id {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
"""Internal utility function to handle the actual deletion of a credential"""
|
||||
associated_connectors = (
|
||||
db_session.query(ConnectorCredentialPair)
|
||||
.filter(ConnectorCredentialPair.credential_id == credential_id)
|
||||
@@ -416,6 +411,35 @@ def delete_credential(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def delete_credential_for_user(
|
||||
credential_id: int,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""Delete a credential that belongs to a specific user"""
|
||||
credential = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if credential is None:
|
||||
raise ValueError(
|
||||
f"Credential by provided id {credential_id} does not exist or does not belong to user"
|
||||
)
|
||||
|
||||
_delete_credential_internal(credential, credential_id, db_session, force)
|
||||
|
||||
|
||||
def delete_credential(
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
force: bool = False,
|
||||
) -> None:
|
||||
"""Delete a credential regardless of ownership (admin function)"""
|
||||
credential = fetch_credential_by_id(credential_id, db_session)
|
||||
if credential is None:
|
||||
raise ValueError(f"Credential by provided id {credential_id} does not exist")
|
||||
|
||||
_delete_credential_internal(credential, credential_id, db_session, force)
|
||||
|
||||
|
||||
def create_initial_public_credential(db_session: Session) -> None:
|
||||
error_msg = (
|
||||
"DB is not in a valid initial state."
|
||||
|
||||
@@ -258,11 +258,11 @@ class SqlEngine:
|
||||
cls._engine = None
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
def get_all_tenant_ids() -> list[str]:
|
||||
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
|
||||
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
return [POSTGRES_DEFAULT_SCHEMA]
|
||||
|
||||
with get_session_with_shared_schema() as session:
|
||||
result = session.execute(
|
||||
@@ -417,7 +417,7 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]:
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
"""
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import validates
|
||||
from typing_extensions import TypedDict # noreorder
|
||||
from uuid import UUID
|
||||
|
||||
@@ -25,6 +26,7 @@ from sqlalchemy import ForeignKey
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import Index
|
||||
from sqlalchemy import Integer
|
||||
|
||||
from sqlalchemy import Sequence
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import Text
|
||||
@@ -205,6 +207,10 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)",
|
||||
)
|
||||
|
||||
@validates("email")
|
||||
def validate_email(self, key: str, value: str) -> str:
|
||||
return value.lower() if value else value
|
||||
|
||||
@property
|
||||
def password_configured(self) -> bool:
|
||||
"""
|
||||
@@ -2269,6 +2275,10 @@ class UserTenantMapping(Base):
|
||||
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
|
||||
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
|
||||
|
||||
@validates("email")
|
||||
def validate_email(self, key: str, value: str) -> str:
|
||||
return value.lower() if value else value
|
||||
|
||||
|
||||
# This is a mapping from tenant IDs to anonymous user paths
|
||||
class TenantAnonymousUserPath(Base):
|
||||
|
||||
@@ -100,9 +100,14 @@ def _add_user_filters(
|
||||
.correlate(Persona)
|
||||
)
|
||||
else:
|
||||
where_clause |= Persona.is_public == True # noqa: E712
|
||||
where_clause &= Persona.is_visible == True # noqa: E712
|
||||
# Group the public persona conditions
|
||||
public_condition = (Persona.is_public == True) & ( # noqa: E712
|
||||
Persona.is_visible == True # noqa: E712
|
||||
)
|
||||
|
||||
where_clause |= public_condition
|
||||
where_clause |= Persona__User.user_id == user.id
|
||||
|
||||
where_clause |= Persona.user_id == user.id
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
@@ -81,7 +81,7 @@ def translate_boost_count_to_multiplier(boost: int) -> float:
|
||||
# Vespa's Document API.
|
||||
def get_document_chunk_ids(
|
||||
enriched_document_info_list: list[EnrichedDocumentIndexingInfo],
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
large_chunks_enabled: bool,
|
||||
) -> list[UUID]:
|
||||
doc_chunk_ids = []
|
||||
@@ -139,7 +139,7 @@ def get_uuid_from_chunk_info(
|
||||
*,
|
||||
document_id: str,
|
||||
chunk_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
large_chunk_id: int | None = None,
|
||||
) -> UUID:
|
||||
"""NOTE: be VERY careful about changing this function. If changed without a migration,
|
||||
@@ -154,7 +154,7 @@ def get_uuid_from_chunk_info(
|
||||
"large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id)
|
||||
)
|
||||
unique_identifier_string = "_".join([doc_str, chunk_index])
|
||||
if tenant_id and MULTI_TENANT:
|
||||
if MULTI_TENANT:
|
||||
unique_identifier_string += "_" + tenant_id
|
||||
|
||||
uuid_value = uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
||||
|
||||
@@ -43,7 +43,7 @@ class IndexBatchParams:
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None]
|
||||
doc_id_to_new_chunk_cnt: dict[str, int]
|
||||
tenant_id: str | None
|
||||
tenant_id: str
|
||||
large_chunks_enabled: bool
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class Deletable(abc.ABC):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -249,7 +249,7 @@ class Updatable(abc.ABC):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
@@ -270,9 +270,7 @@ class Updatable(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(
|
||||
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
|
||||
) -> None:
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
"""
|
||||
Updates some set of chunks. The document and fields to update are specified in the update
|
||||
requests. Each update request in the list applies its changes to a list of document ids.
|
||||
|
||||
@@ -468,9 +468,7 @@ class VespaIndex(DocumentIndex):
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(
|
||||
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
|
||||
) -> None:
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
@@ -618,7 +616,7 @@ class VespaIndex(DocumentIndex):
|
||||
doc_id: str,
|
||||
*,
|
||||
chunk_count: int | None,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
@@ -661,7 +659,7 @@ class VespaIndex(DocumentIndex):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
total_chunks_deleted = 0
|
||||
|
||||
@@ -158,8 +158,8 @@ def index_doc_batch_with_handler(
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
) -> IndexingPipelineResult:
|
||||
try:
|
||||
index_pipeline_result = index_doc_batch(
|
||||
@@ -317,8 +317,8 @@ def index_doc_batch(
|
||||
document_index: DocumentIndex,
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
|
||||
) -> IndexingPipelineResult:
|
||||
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
|
||||
@@ -525,9 +525,9 @@ def build_indexing_pipeline(
|
||||
embedder: IndexingEmbedder,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
chunker: Chunker | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
|
||||
@@ -84,7 +84,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
negative -> ranked lower.
|
||||
"""
|
||||
|
||||
tenant_id: str | None = None
|
||||
tenant_id: str
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
boost: int
|
||||
@@ -96,7 +96,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
boost: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> "DocMetadataAwareIndexChunk":
|
||||
index_chunk_data = index_chunk.model_dump()
|
||||
return cls(
|
||||
|
||||
@@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
from onyx.server.documents.connector import router as connector_router
|
||||
from onyx.server.documents.credential import router as credential_router
|
||||
from onyx.server.documents.document import router as document_router
|
||||
from onyx.server.documents.standard_oauth import router as oauth_router
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.folder.api import router as folder_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
@@ -219,7 +218,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
# If we are multi-tenant, we need to only set up initial public tables
|
||||
with Session(engine) as db_session:
|
||||
setup_onyx(db_session, None)
|
||||
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
|
||||
else:
|
||||
setup_multitenant_onyx()
|
||||
|
||||
@@ -323,7 +322,6 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, long_term_logs_router)
|
||||
include_router_with_global_prefix_prepended(application, api_key_router)
|
||||
include_router_with_global_prefix_prepended(application, oauth_router)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
|
||||
@@ -23,7 +23,7 @@ from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.onyxbot_configs import DANSWER_BOT_NUM_DOCS_TO_DISPLAY
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.db.chat import get_chat_session_by_message_id
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ChannelConfig
|
||||
from onyx.onyxbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID
|
||||
from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID
|
||||
@@ -410,12 +410,11 @@ def _build_qa_response_blocks(
|
||||
|
||||
|
||||
def _build_continue_in_web_ui_block(
|
||||
tenant_id: str | None,
|
||||
message_id: int | None,
|
||||
) -> Block:
|
||||
if message_id is None:
|
||||
raise ValueError("No message id provided to build continue in web ui block")
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
chat_session = get_chat_session_by_message_id(
|
||||
db_session=db_session,
|
||||
message_id=message_id,
|
||||
@@ -482,7 +481,6 @@ def build_follow_up_resolved_blocks(
|
||||
|
||||
def build_slack_response_blocks(
|
||||
answer: ChatOnyxBotResponse,
|
||||
tenant_id: str | None,
|
||||
message_info: SlackMessageInfo,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
@@ -517,7 +515,6 @@ def build_slack_response_blocks(
|
||||
if channel_conf and channel_conf.get("show_continue_in_web_ui"):
|
||||
web_follow_up_block.append(
|
||||
_build_continue_in_web_ui_block(
|
||||
tenant_id=tenant_id,
|
||||
message_id=answer.chat_message_id,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -11,7 +11,7 @@ from onyx.configs.constants import SearchFeedbackType
|
||||
from onyx.configs.onyxbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
from onyx.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.feedback import create_chat_message_feedback
|
||||
from onyx.db.feedback import create_doc_retrieval_feedback
|
||||
from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks
|
||||
@@ -114,7 +114,7 @@ def handle_generate_answer_button(
|
||||
thread_ts=thread_ts,
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
@@ -136,7 +136,6 @@ def handle_generate_answer_button(
|
||||
slack_channel_config=slack_channel_config,
|
||||
receiver_ids=None,
|
||||
client=client.web_client,
|
||||
tenant_id=client.tenant_id,
|
||||
channel=channel_id,
|
||||
logger=logger,
|
||||
feedback_reminder_id=None,
|
||||
@@ -151,11 +150,10 @@ def handle_slack_feedback(
|
||||
user_id_to_post_confirmation: str,
|
||||
channel_id_to_post_confirmation: str,
|
||||
thread_ts_to_post_confirmation: str,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]:
|
||||
create_chat_message_feedback(
|
||||
is_positive=feedback_type == LIKE_BLOCK_ACTION_ID,
|
||||
@@ -246,7 +244,7 @@ def handle_followup_button(
|
||||
|
||||
tag_ids: list[str] = []
|
||||
group_ids: list[str] = []
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
channel_name, is_dm = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel_id
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.configs.onyxbot_configs import DANSWER_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import DANSWER_REACT_EMOJI
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks
|
||||
@@ -109,7 +109,6 @@ def handle_message(
|
||||
slack_channel_config: SlackChannelConfig,
|
||||
client: WebClient,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
) -> bool:
|
||||
"""Potentially respond to the user message depending on filters and if an answer was generated
|
||||
|
||||
@@ -135,9 +134,7 @@ def handle_message(
|
||||
action = "slack_tag_message"
|
||||
elif is_bot_dm:
|
||||
action = "slack_dm_message"
|
||||
slack_usage_report(
|
||||
action=action, sender_id=sender_id, client=client, tenant_id=tenant_id
|
||||
)
|
||||
slack_usage_report(action=action, sender_id=sender_id, client=client)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
persona = slack_channel_config.persona if slack_channel_config else None
|
||||
@@ -218,7 +215,7 @@ def handle_message(
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if message_info.email:
|
||||
add_slack_user_if_not_exists(db_session, message_info.email)
|
||||
|
||||
@@ -244,6 +241,5 @@ def handle_message(
|
||||
channel=channel,
|
||||
logger=logger,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
return issue_with_regular_answer
|
||||
|
||||
@@ -24,7 +24,6 @@ from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
@@ -72,7 +71,6 @@ def handle_regular_answer(
|
||||
channel: str,
|
||||
logger: OnyxLoggingAdapter,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
@@ -87,7 +85,7 @@ def handle_regular_answer(
|
||||
user = None
|
||||
if message_info.is_bot_dm:
|
||||
if message_info.email:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
@@ -96,7 +94,7 @@ def handle_regular_answer(
|
||||
# This way slack flow always has a persona
|
||||
persona = slack_channel_config.persona
|
||||
if not persona:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
@@ -157,7 +155,7 @@ def handle_regular_answer(
|
||||
def _get_slack_answer(
|
||||
new_message_request: CreateChatMessageRequest, onyx_user: User | None
|
||||
) -> ChatOnyxBotResponse:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=new_message_request,
|
||||
user=onyx_user,
|
||||
@@ -197,7 +195,7 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
answer_request = prepare_chat_message_request(
|
||||
message_text=user_message.message,
|
||||
user=user,
|
||||
@@ -361,7 +359,6 @@ def handle_regular_answer(
|
||||
return True
|
||||
|
||||
all_blocks = build_slack_response_blocks(
|
||||
tenant_id=tenant_id,
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
channel_conf=channel_conf,
|
||||
|
||||
@@ -37,6 +37,7 @@ from onyx.context.search.retrieval.search_runner import (
|
||||
download_nltk_data,
|
||||
)
|
||||
from onyx.db.engine import get_all_tenant_ids
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.models import SlackBot
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -92,6 +93,7 @@ from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -123,13 +125,13 @@ _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
|
||||
class SlackbotHandler:
|
||||
def __init__(self) -> None:
|
||||
logger.info("Initializing SlackbotHandler")
|
||||
self.tenant_ids: Set[str | None] = set()
|
||||
self.tenant_ids: Set[str] = set()
|
||||
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
|
||||
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
|
||||
self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {}
|
||||
|
||||
# Store Redis lock objects here so we can release them properly
|
||||
self.redis_locks: Dict[str | None, Lock] = {}
|
||||
self.redis_locks: Dict[str, Lock] = {}
|
||||
|
||||
self.running = True
|
||||
self.pod_id = self.get_pod_id()
|
||||
@@ -193,7 +195,7 @@ class SlackbotHandler:
|
||||
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
|
||||
|
||||
def _manage_clients_per_tenant(
|
||||
self, db_session: Session, tenant_id: str | None, bot: SlackBot
|
||||
self, db_session: Session, tenant_id: str, bot: SlackBot
|
||||
) -> None:
|
||||
"""
|
||||
- If the tokens are missing or empty, close the socket client and remove them.
|
||||
@@ -347,7 +349,7 @@ class SlackbotHandler:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Attempt to fetch Slack bots
|
||||
try:
|
||||
bots = list(fetch_slack_bots(db_session=db_session))
|
||||
@@ -385,7 +387,7 @@ class SlackbotHandler:
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def _remove_tenant(self, tenant_id: str | None) -> None:
|
||||
def _remove_tenant(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
|
||||
(Lock release now happens in `acquire_tenants()`, not here.)
|
||||
@@ -415,7 +417,7 @@ class SlackbotHandler:
|
||||
)
|
||||
|
||||
def start_socket_client(
|
||||
self, slack_bot_id: int, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
|
||||
self, slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
|
||||
) -> None:
|
||||
socket_client: TenantSocketModeClient = _get_socket_client(
|
||||
slack_bot_tokens, tenant_id, slack_bot_id
|
||||
@@ -586,7 +588,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
channel_name, _ = get_channel_name_from_id(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
@@ -680,7 +682,6 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
|
||||
user_id_to_post_confirmation=user_id,
|
||||
channel_id_to_post_confirmation=channel_id,
|
||||
thread_ts_to_post_confirmation=thread_ts,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
|
||||
query_event_id, _, _ = decompose_action_id(feedback_id)
|
||||
@@ -796,8 +797,9 @@ def process_message(
|
||||
respond_every_channel: bool = DANSWER_BOT_RESPOND_EVERY_CHANNEL,
|
||||
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
logger.debug(
|
||||
f"Received Slack request of type: '{req.type}' for tenant, {client.tenant_id}"
|
||||
f"Received Slack request of type: '{req.type}' for tenant, {tenant_id}"
|
||||
)
|
||||
|
||||
# Throw out requests that can't or shouldn't be handled
|
||||
@@ -810,50 +812,39 @@ def process_message(
|
||||
client=client.web_client, channel_id=channel
|
||||
)
|
||||
|
||||
token: Token[str | None] | None = None
|
||||
# Set the current tenant ID at the beginning for all DB calls within this thread
|
||||
if client.tenant_id:
|
||||
logger.info(f"Setting tenant ID to {client.tenant_id}")
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(client.tenant_id)
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id=client.tenant_id) as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
channel_name=channel_name,
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
slack_channel_config = get_slack_channel_config_for_bot_and_channel(
|
||||
db_session=db_session,
|
||||
slack_bot_id=client.slack_bot_id,
|
||||
channel_name=channel_name,
|
||||
)
|
||||
|
||||
follow_up = bool(
|
||||
slack_channel_config.channel_config
|
||||
and slack_channel_config.channel_config.get("follow_up_tags")
|
||||
is not None
|
||||
)
|
||||
follow_up = bool(
|
||||
slack_channel_config.channel_config
|
||||
and slack_channel_config.channel_config.get("follow_up_tags") is not None
|
||||
)
|
||||
|
||||
feedback_reminder_id = schedule_feedback_reminder(
|
||||
details=details, client=client.web_client, include_followup=follow_up
|
||||
)
|
||||
feedback_reminder_id = schedule_feedback_reminder(
|
||||
details=details, client=client.web_client, include_followup=follow_up
|
||||
)
|
||||
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
slack_channel_config=slack_channel_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
tenant_id=client.tenant_id,
|
||||
)
|
||||
failed = handle_message(
|
||||
message_info=details,
|
||||
slack_channel_config=slack_channel_config,
|
||||
client=client.web_client,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
if failed:
|
||||
if feedback_reminder_id:
|
||||
remove_scheduled_feedback_reminder(
|
||||
client=client.web_client,
|
||||
channel=details.sender_id,
|
||||
msg_id=feedback_reminder_id,
|
||||
)
|
||||
# Skipping answering due to pre-filtering is not considered a failure
|
||||
if notify_no_answer:
|
||||
apologize_for_fail(details, client)
|
||||
finally:
|
||||
if token:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
if failed:
|
||||
if feedback_reminder_id:
|
||||
remove_scheduled_feedback_reminder(
|
||||
client=client.web_client,
|
||||
channel=details.sender_id,
|
||||
msg_id=feedback_reminder_id,
|
||||
)
|
||||
# Skipping answering due to pre-filtering is not considered a failure
|
||||
if notify_no_answer:
|
||||
apologize_for_fail(details, client)
|
||||
|
||||
|
||||
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
@@ -912,7 +903,7 @@ def create_process_slack_event() -> (
|
||||
|
||||
|
||||
def _get_socket_client(
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str | None, slack_bot_id: int
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str, slack_bot_id: int
|
||||
) -> TenantSocketModeClient:
|
||||
# For more info on how to set this up, checkout the docs:
|
||||
# https://docs.onyx.app/slack_bot_setup
|
||||
|
||||
@@ -4,6 +4,8 @@ import re
|
||||
import string
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -30,7 +32,7 @@ from onyx.configs.onyxbot_configs import (
|
||||
)
|
||||
from onyx.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from onyx.connectors.slack.utils import SlackTextCleaner
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_default_llms
|
||||
@@ -43,6 +45,7 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.text_processing import replace_whitespaces_w_space
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -569,9 +572,7 @@ def read_slack_thread(
|
||||
return thread_messages
|
||||
|
||||
|
||||
def slack_usage_report(
|
||||
action: str, sender_id: str | None, client: WebClient, tenant_id: str | None
|
||||
) -> None:
|
||||
def slack_usage_report(action: str, sender_id: str | None, client: WebClient) -> None:
|
||||
if DISABLE_TELEMETRY:
|
||||
return
|
||||
|
||||
@@ -583,14 +584,13 @@ def slack_usage_report(
|
||||
logger.warning("Unable to find sender email")
|
||||
|
||||
if sender_email is not None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
onyx_user = get_user_by_email(email=sender_email, db_session=db_session)
|
||||
|
||||
optional_telemetry(
|
||||
record_type=RecordType.USAGE,
|
||||
data={"action": action},
|
||||
user_id=str(onyx_user.id) if onyx_user else "Non-Onyx-Or-No-Auth-User",
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -663,9 +663,30 @@ def get_feedback_visibility() -> FeedbackVisibility:
|
||||
|
||||
|
||||
class TenantSocketModeClient(SocketModeClient):
|
||||
def __init__(
|
||||
self, tenant_id: str | None, slack_bot_id: int, *args: Any, **kwargs: Any
|
||||
):
|
||||
def __init__(self, tenant_id: str, slack_bot_id: int, *args: Any, **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tenant_id = tenant_id
|
||||
self._tenant_id = tenant_id
|
||||
self.slack_bot_id = slack_bot_id
|
||||
|
||||
@contextmanager
|
||||
def _set_tenant_context(self) -> Generator[None, None, None]:
|
||||
token = None
|
||||
try:
|
||||
if self._tenant_id:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self._tenant_id)
|
||||
yield
|
||||
finally:
|
||||
if token:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def enqueue_message(self, message: str) -> None:
|
||||
with self._set_tenant_context():
|
||||
super().enqueue_message(message)
|
||||
|
||||
def process_message(self) -> None:
|
||||
with self._set_tenant_context():
|
||||
super().process_message()
|
||||
|
||||
def run_message_listeners(self, message: dict, raw_message: str) -> None:
|
||||
with self._set_tenant_context():
|
||||
super().run_message_listeners(message, raw_message)
|
||||
|
||||
@@ -16,10 +16,10 @@ class RedisConnector:
|
||||
"""Composes several classes to simplify interacting with a connector and its
|
||||
associated background tasks / associated redis interactions."""
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
"""id: a connector credential pair id"""
|
||||
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id: int = id
|
||||
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
PREFIX = "connectorsync"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
# documents that should be skipped
|
||||
@@ -60,7 +60,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""We can limit the number of tasks generated here, which is useful to prevent
|
||||
one tenant from overwhelming the sync queue.
|
||||
|
||||
@@ -39,8 +39,8 @@ class RedisConnectorDelete:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -52,8 +52,8 @@ class RedisConnectorPermissionSync:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -44,8 +44,8 @@ class RedisConnectorExternalGroupSync:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -52,12 +52,12 @@ class RedisConnectorIndex:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
id: int,
|
||||
search_settings_id: int,
|
||||
redis: redis.Redis,
|
||||
) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.search_settings_id = search_settings_id
|
||||
self.redis = redis
|
||||
|
||||
@@ -52,8 +52,8 @@ class RedisConnectorPrune:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ class RedisConnectorStop:
|
||||
TIMEOUT_PREFIX = f"{PREFIX}_timeout"
|
||||
TIMEOUT_TTL = 300
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id: int = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
@@ -58,7 +58,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Max tasks is ignored for now until we can build the logic to mark the
|
||||
document set up to date over multiple batches.
|
||||
|
||||
@@ -14,8 +14,8 @@ class RedisObjectHelper(ABC):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: str):
|
||||
self._tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: str):
|
||||
self._tenant_id: str = tenant_id
|
||||
self._id: str = id
|
||||
self.redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -87,7 +87,7 @@ class RedisObjectHelper(ABC):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""First element should be the number of actual tasks generated, second should
|
||||
be the number of docs that were candidates to be synced for the cc pair.
|
||||
|
||||
@@ -24,7 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
@@ -59,7 +59,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Max tasks is ignored for now until we can build the logic to mark the
|
||||
user group up to date over multiple batches.
|
||||
|
||||
@@ -37,13 +37,15 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_indexable_chunks(
|
||||
preprocessed_docs: list[dict],
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[list[Document], list[DocMetadataAwareIndexChunk]]:
|
||||
ids_to_documents = {}
|
||||
chunks = []
|
||||
@@ -86,7 +88,7 @@ def _create_indexable_chunks(
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=preprocessed_doc["title_embedding"],
|
||||
tenant_id=tenant_id,
|
||||
tenant_id=tenant_id if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA,
|
||||
access=default_public_access,
|
||||
document_sets=set(),
|
||||
boost=DEFAULT_BOOST,
|
||||
@@ -111,7 +113,7 @@ def load_processed_docs(cohere_enabled: bool) -> list[dict]:
|
||||
|
||||
|
||||
def seed_initial_documents(
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
db_session: Session, tenant_id: str, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Seed initial documents so users don't have an empty index to start
|
||||
|
||||
@@ -620,7 +620,7 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
try:
|
||||
validate_ccpair_for_user(connector_id, credential_id, db_session, tenant_id)
|
||||
validate_ccpair_for_user(connector_id, credential_id, db_session)
|
||||
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
@@ -646,7 +646,6 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except ValidationError as e:
|
||||
# If validation fails, delete the connector and commit the changes
|
||||
# Ensures we don't leave invalid connectors in the database
|
||||
@@ -660,10 +659,14 @@ def associate_credential_to_connector(
|
||||
)
|
||||
except IntegrityError as e:
|
||||
logger.error(f"IntegrityError: {e}")
|
||||
delete_connector(db_session, connector_id)
|
||||
db_session.commit()
|
||||
|
||||
raise HTTPException(status_code=400, detail="Name must be unique")
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Unexpected error: {e}")
|
||||
|
||||
raise HTTPException(status_code=500, detail="Unexpected error")
|
||||
|
||||
|
||||
|
||||
@@ -902,7 +902,6 @@ def create_connector_with_mock_credential(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -13,12 +13,12 @@ from onyx.db.credentials import cleanup_gmail_credentials
|
||||
from onyx.db.credentials import create_credential
|
||||
from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
|
||||
from onyx.db.credentials import delete_credential
|
||||
from onyx.db.credentials import delete_credential_for_user
|
||||
from onyx.db.credentials import fetch_credential_by_id_for_user
|
||||
from onyx.db.credentials import fetch_credentials_by_source_for_user
|
||||
from onyx.db.credentials import fetch_credentials_for_user
|
||||
from onyx.db.credentials import swap_credentials_connector
|
||||
from onyx.db.credentials import update_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import DocumentSource
|
||||
from onyx.db.models import User
|
||||
@@ -89,7 +89,7 @@ def delete_credential_by_id_admin(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
"""Same as the user endpoint, but can delete any credential (not just the user's own)"""
|
||||
delete_credential(db_session=db_session, credential_id=credential_id, user=None)
|
||||
delete_credential(db_session=db_session, credential_id=credential_id)
|
||||
return StatusResponse(
|
||||
success=True, message="Credential deleted successfully", data=credential_id
|
||||
)
|
||||
@@ -100,13 +100,11 @@ def swap_credentials_for_connector(
|
||||
credential_swap_req: CredentialSwapRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse:
|
||||
validate_ccpair_for_user(
|
||||
credential_swap_req.connector_id,
|
||||
credential_swap_req.new_credential_id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
connector_credential_pair = swap_credentials_connector(
|
||||
@@ -245,7 +243,7 @@ def delete_credential_by_id(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
delete_credential(
|
||||
delete_credential_for_user(
|
||||
credential_id,
|
||||
user,
|
||||
db_session,
|
||||
@@ -262,7 +260,7 @@ def force_delete_credential_by_id(
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StatusResponse:
|
||||
delete_credential(credential_id, user, db_session, True)
|
||||
delete_credential_for_user(credential_id, user, db_session, True)
|
||||
|
||||
return StatusResponse(
|
||||
success=True, message="Credential deleted successfully", data=credential_id
|
||||
|
||||
@@ -49,6 +49,7 @@ def get_folders(
|
||||
name=chat_session.description,
|
||||
persona_id=chat_session.persona_id,
|
||||
time_created=chat_session.time_created.isoformat(),
|
||||
time_updated=chat_session.time_updated.isoformat(),
|
||||
shared_status=chat_session.shared_status,
|
||||
folder_id=folder.id,
|
||||
)
|
||||
|
||||
@@ -343,7 +343,8 @@ def list_bot_configs(
|
||||
]
|
||||
|
||||
|
||||
MAX_CHANNELS = 200
|
||||
MAX_SLACK_PAGES = 5
|
||||
SLACK_API_CHANNELS_PER_PAGE = 100
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -355,8 +356,8 @@ def get_all_channels_from_slack_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[SlackChannel]:
|
||||
"""
|
||||
Fetches all channels from the Slack API.
|
||||
If the workspace has 200 or more channels, we raise an error.
|
||||
Fetches channels the bot is a member of from the Slack API.
|
||||
Handles pagination with a limit to avoid excessive API calls.
|
||||
"""
|
||||
tokens = fetch_slack_bot_tokens(db_session, bot_id)
|
||||
if not tokens or "bot_token" not in tokens:
|
||||
@@ -365,28 +366,60 @@ def get_all_channels_from_slack_api(
|
||||
)
|
||||
|
||||
client = WebClient(token=tokens["bot_token"])
|
||||
all_channels = []
|
||||
next_cursor = None
|
||||
current_page = 0
|
||||
|
||||
try:
|
||||
response = client.conversations_list(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
limit=MAX_CHANNELS,
|
||||
)
|
||||
# Use users_conversations with limited pagination
|
||||
while current_page < MAX_SLACK_PAGES:
|
||||
current_page += 1
|
||||
|
||||
# Make API call with cursor if we have one
|
||||
if next_cursor:
|
||||
response = client.users_conversations(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
cursor=next_cursor,
|
||||
limit=SLACK_API_CHANNELS_PER_PAGE,
|
||||
)
|
||||
else:
|
||||
response = client.users_conversations(
|
||||
types="public_channel,private_channel",
|
||||
exclude_archived=True,
|
||||
limit=SLACK_API_CHANNELS_PER_PAGE,
|
||||
)
|
||||
|
||||
# Add channels to our list
|
||||
if "channels" in response and response["channels"]:
|
||||
all_channels.extend(response["channels"])
|
||||
|
||||
# Check if we need to paginate
|
||||
if (
|
||||
"response_metadata" in response
|
||||
and "next_cursor" in response["response_metadata"]
|
||||
):
|
||||
next_cursor = response["response_metadata"]["next_cursor"]
|
||||
if next_cursor:
|
||||
if current_page == MAX_SLACK_PAGES:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Workspace has too many channels to paginate over in this call.",
|
||||
)
|
||||
continue
|
||||
|
||||
# If we get here, no more pages
|
||||
break
|
||||
|
||||
channels = [
|
||||
SlackChannel(id=channel["id"], name=channel["name"])
|
||||
for channel in response["channels"]
|
||||
for channel in all_channels
|
||||
]
|
||||
|
||||
if len(channels) == MAX_CHANNELS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Workspace has {MAX_CHANNELS} or more channels.",
|
||||
)
|
||||
|
||||
return channels
|
||||
|
||||
except SlackApiError as e:
|
||||
# Handle rate limiting or other API errors
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Error fetching channels from Slack API: {str(e)}",
|
||||
|
||||
@@ -147,9 +147,11 @@ def list_threads(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
current_temperature_override=chat.temperature_override,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
]
|
||||
|
||||
@@ -119,6 +119,7 @@ def get_user_chat_sessions(
|
||||
name=chat.description,
|
||||
persona_id=chat.persona_id,
|
||||
time_created=chat.time_created.isoformat(),
|
||||
time_updated=chat.time_updated.isoformat(),
|
||||
shared_status=chat.shared_status,
|
||||
folder_id=chat.folder_id,
|
||||
current_alternate_model=chat.current_alternate_model,
|
||||
|
||||
@@ -181,6 +181,7 @@ class ChatSessionDetails(BaseModel):
|
||||
name: str | None
|
||||
persona_id: int | None = None
|
||||
time_created: str
|
||||
time_updated: str
|
||||
shared_status: ChatSessionSharedStatus
|
||||
folder_id: int | None = None
|
||||
current_alternate_model: str | None = None
|
||||
@@ -241,6 +242,7 @@ class ChatMessageDetail(BaseModel):
|
||||
files: list[FileDescriptor]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
refined_answer_improvement: bool | None = None
|
||||
is_agentic: bool | None = None
|
||||
error: str | None = None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
|
||||
@@ -159,6 +159,7 @@ def get_user_search_sessions(
|
||||
name=sessions_with_documents_dict[search.id],
|
||||
persona_id=search.persona_id,
|
||||
time_created=search.time_created.isoformat(),
|
||||
time_updated=search.time_updated.isoformat(),
|
||||
shared_status=search.shared_status,
|
||||
folder_id=search.folder_id,
|
||||
current_alternate_model=search.current_alternate_model,
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user