Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-03-01 05:35:46 +00:00)

Compare commits (38 commits)
Branch: v1.2.0-clo
| SHA1 |
|---|
| 555630070b |
| 1d16c96009 |
| 297720c132 |
| bd4bd00cef |
| 07c482f727 |
| cf193dee29 |
| 1b47fa2700 |
| e1a305d18a |
| e2233d22c9 |
| 20d1175312 |
| 7117774287 |
| 77f2660bb2 |
| 1b2f4f3b87 |
| d85b55a9d2 |
| e2bae5a2d9 |
| cc9c76c4fb |
| 258e08abcd |
| 67047e42a7 |
| 146628e734 |
| c1d4b08132 |
| f3f47d0709 |
| fe26a1bfcc |
| 554cd0f891 |
| f87d3e9849 |
| 72cdada893 |
| c442ebaff6 |
| 56f16d107e |
| 0157ae099a |
| 565fb42457 |
| a50a8b4a12 |
| 4baf4e7d96 |
| 8b7ab2eb66 |
| 1f75f3633e |
| 650884d76a |
| 8722bdb414 |
| 71037678c3 |
| 68de1015e1 |
| e2b3a6e144 |
2 .github/CODEOWNERS (vendored)
@@ -1 +1,3 @@
* @onyx-dot-app/onyx-core-team
# Helm charts Owners
/helm/ @justin-tahara
42 .github/workflows/helm-chart-releases.yml (vendored, new file)
@@ -0,0 +1,42 @@
name: Release Onyx Helm Charts

on:
  push:
    branches:
      - main

permissions: write-all

jobs:
  release:
    permissions:
      contents: write
    runs-on: ubuntu-latest
    steps:
      - name: Checkout
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Configure Git
        run: |
          git config user.name "$GITHUB_ACTOR"
          git config user.email "$GITHUB_ACTOR@users.noreply.github.com"

      - name: Install Helm
        uses: azure/setup-helm@v4
        with:
          version: v3.12.1

      - name: Add Required Helm Repositories
        run: |
          helm repo add bitnami https://charts.bitnami.com/bitnami
          helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
          helm repo update

      - name: Run chart-releaser
        uses: helm/chart-releaser-action@v1.7.0
        with:
          charts_dir: deployment/helm/charts
        env:
          CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}"
@@ -0,0 +1,30 @@
"""add_doc_metadata_field_in_document_model

Revision ID: 3fc5d75723b3
Revises: 2f95e36923e6
Create Date: 2025-07-28 18:45:37.985406

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "3fc5d75723b3"
down_revision = "2f95e36923e6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "document",
        sa.Column(
            "doc_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
    )


def downgrade() -> None:
    op.drop_column("document", "doc_metadata")
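The new doc_metadata column is what the GitHub doc sync later in this diff reads back when it groups documents by repository (DocMetadata.model_validate_json(json.dumps(doc.doc_metadata))). As a rough sketch of that round trip, with the assumption that DocMetadata (defined in onyx/connectors/github/connector.py, not shown in this diff) carries at least a repo field:

# Illustrative sketch only; DocMetadataSketch is a hypothetical stand-in for
# onyx.connectors.github.connector.DocMetadata, whose real fields are not in this diff.
import json
from pydantic import BaseModel

class DocMetadataSketch(BaseModel):
    repo: str  # assumed field; doc_sync.py groups documents by doc_metadata.repo

stored_doc_metadata = {"repo": "onyx-dot-app/onyx"}  # example value as it would sit in document.doc_metadata
meta = DocMetadataSketch.model_validate_json(json.dumps(stored_doc_metadata))
assert meta.repo == "onyx-dot-app/onyx"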
@@ -0,0 +1,132 @@
"""add file names to file connector config

Revision ID: 62c3a055a141
Revises: 3fc5d75723b3
Create Date: 2025-07-30 17:01:24.417551

"""

from alembic import op
import sqlalchemy as sa
import json
import os
import logging


# revision identifiers, used by Alembic.
revision = "62c3a055a141"
down_revision = "3fc5d75723b3"
branch_labels = None
depends_on = None

SKIP_FILE_NAME_MIGRATION = (
    os.environ.get("SKIP_FILE_NAME_MIGRATION", "true").lower() == "true"
)

logger = logging.getLogger("alembic.runtime.migration")


def upgrade() -> None:
    if SKIP_FILE_NAME_MIGRATION:
        logger.info(
            "Skipping file name migration. Hint: set SKIP_FILE_NAME_MIGRATION=false to run this migration"
        )
        return
    logger.info("Running file name migration")
    # Get connection
    conn = op.get_bind()

    # Get all FILE connectors with their configs
    file_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = 'FILE'
            """
        )
    ).fetchall()

    for connector_id, config in file_connectors:
        # Parse config if it's a string
        if isinstance(config, str):
            config = json.loads(config)

        # Get file_locations list
        file_locations = config.get("file_locations", [])

        # Get display names for each file_id
        file_names = []
        for file_id in file_locations:
            result = conn.execute(
                sa.text(
                    """
                    SELECT display_name
                    FROM file_record
                    WHERE file_id = :file_id
                    """
                ),
                {"file_id": file_id},
            ).fetchone()

            if result:
                file_names.append(result[0])
            else:
                file_names.append(file_id)  # Should not happen

        # Add file_names to config
        new_config = dict(config)
        new_config["file_names"] = file_names

        # Update the connector
        conn.execute(
            sa.text(
                """
                UPDATE connector
                SET connector_specific_config = :new_config
                WHERE id = :connector_id
                """
            ),
            {"connector_id": connector_id, "new_config": json.dumps(new_config)},
        )


def downgrade() -> None:
    # Get connection
    conn = op.get_bind()

    # Remove file_names from all FILE connectors
    file_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = 'FILE'
            """
        )
    ).fetchall()

    for connector_id, config in file_connectors:
        # Parse config if it's a string
        if isinstance(config, str):
            config = json.loads(config)

        # Remove file_names if it exists
        if "file_names" in config:
            new_config = dict(config)
            del new_config["file_names"]

            # Update the connector
            conn.execute(
                sa.text(
                    """
                    UPDATE connector
                    SET connector_specific_config = :new_config
                    WHERE id = :connector_id
                    """
                ),
                {
                    "connector_id": connector_id,
                    "new_config": json.dumps(new_config),
                },
            )
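For orientation, a hedged sketch of what this backfill does to a FILE connector's connector_specific_config; the keys mirror the migration above, the concrete values are made up:

# Illustrative shapes only, not real rows.
config_before = {"file_locations": ["file-id-1", "file-id-2"]}

# For each file_id, display_name is looked up in file_record; when no row is
# found the migration falls back to the file_id itself.
config_after = {
    "file_locations": ["file-id-1", "file-id-2"],
    "file_names": ["quarterly_report.pdf", "file-id-2"],
}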
@@ -47,6 +47,7 @@ from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import get_documents_for_connector_credential_pair_limited_columns
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
@@ -58,7 +59,9 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.utils import DocumentRow
from onyx.db.utils import is_retryable_sqlalchemy_error
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -498,16 +501,31 @@ def connector_permission_sync_generator_task(
        # this can be used to determine documents that are "missing" and thus
        # should no longer be accessible. The decision as to whether we should find
        # every document during the doc sync process is connector-specific.
        def fetch_all_existing_docs_fn() -> list[str]:
            return get_document_ids_for_connector_credential_pair(
        def fetch_all_existing_docs_fn(
            sort_order: SortOrder | None = None,
        ) -> list[DocumentRow]:
            result = get_documents_for_connector_credential_pair_limited_columns(
                db_session=db_session,
                connector_id=cc_pair.connector.id,
                credential_id=cc_pair.credential.id,
                sort_order=sort_order,
            )
            return list(result)

        def fetch_all_existing_docs_ids_fn() -> list[str]:
            result = get_document_ids_for_connector_credential_pair(
                db_session=db_session,
                connector_id=cc_pair.connector.id,
                credential_id=cc_pair.credential.id,
            )
            return result

        doc_sync_func = sync_config.doc_sync_config.doc_sync_func
        document_external_accesses = doc_sync_func(
            cc_pair, fetch_all_existing_docs_fn, callback
            cc_pair,
            fetch_all_existing_docs_fn,
            fetch_all_existing_docs_ids_fn,
            callback,
        )

        task_logger.info(
@@ -71,6 +71,19 @@ GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
)


#####
# GitHub
#####
# In seconds, default is 5 minutes
GITHUB_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("GITHUB_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY = int(
    os.environ.get("GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)


#####
# Slack
#####
|
||||
<!-- <document type="danswer_chunk" mode="index" /> -->
|
||||
{{ document_elements }}
|
||||
</documents>
|
||||
<nodes count="75">
|
||||
<resources vcpu="8.0" memory="64.0Gb" architecture="arm64" storage-type="local"
|
||||
disk="474.0Gb" />
|
||||
<nodes count="60">
|
||||
<resources vcpu="8.0" memory="128.0Gb" architecture="arm64" storage-type="local"
|
||||
disk="475.0Gb" />
|
||||
</nodes>
|
||||
<engine>
|
||||
<proton>
|
||||
|
||||
@@ -6,6 +6,7 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
from collections.abc import Generator

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -25,6 +26,7 @@ CONFLUENCE_DOC_SYNC_LABEL = "confluence_doc_sync"
def confluence_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
    callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
    """
@@ -43,7 +45,7 @@ def confluence_doc_sync(

    yield from generic_doc_sync(
        cc_pair=cc_pair,
        fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
        fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
        callback=callback,
        doc_source=DocumentSource.CONFLUENCE,
        slim_connector=confluence_connector,
294 backend/ee/onyx/external_permissions/github/doc_sync.py (new file)
@@ -0,0 +1,294 @@
import json
from collections.abc import Generator

from github import Github
from github.Repository import Repository

from ee.onyx.external_permissions.github.utils import fetch_repository_team_slugs
from ee.onyx.external_permissions.github.utils import form_collaborators_group_id
from ee.onyx.external_permissions.github.utils import form_organization_group_id
from ee.onyx.external_permissions.github.utils import (
    form_outside_collaborators_group_id,
)
from ee.onyx.external_permissions.github.utils import get_external_access_permission
from ee.onyx.external_permissions.github.utils import get_repository_visibility
from ee.onyx.external_permissions.github.utils import GitHubVisibility
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.connector import DocMetadata
from onyx.connectors.github.connector import GithubConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()

GITHUB_DOC_SYNC_LABEL = "github_doc_sync"


def github_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
    callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
    """
    Sync GitHub documents with external access permissions.

    This function checks each repository for visibility/team changes and updates
    document permissions accordingly without using checkpoints.
    """
    logger.info(f"Starting GitHub document sync for CC pair ID: {cc_pair.id}")

    # Initialize GitHub connector with credentials
    github_connector: GithubConnector = GithubConnector(
        **cc_pair.connector.connector_specific_config
    )

    github_connector.load_credentials(cc_pair.credential.credential_json)
    logger.info("GitHub connector credentials loaded successfully")

    if not github_connector.github_client:
        logger.error("GitHub client initialization failed")
        raise ValueError("github_client is required")

    # Get all repositories from GitHub API
    logger.info("Fetching all repositories from GitHub API")
    try:
        repos = []
        if github_connector.repositories:
            if "," in github_connector.repositories:
                # Multiple repositories specified
                repos = github_connector.get_github_repos(
                    github_connector.github_client
                )
            else:
                # Single repository
                repos = [
                    github_connector.get_github_repo(github_connector.github_client)
                ]
        else:
            # All repositories
            repos = github_connector.get_all_repos(github_connector.github_client)

        logger.info(f"Found {len(repos)} repositories to check")
    except Exception as e:
        logger.error(f"Failed to fetch repositories: {e}")
        raise

    repo_to_doc_list_map: dict[str, list[DocumentRow]] = {}
    # sort order is ascending because we want to get the oldest documents first
    existing_docs: list[DocumentRow] = fetch_all_existing_docs_fn(
        sort_order=SortOrder.ASC
    )
    logger.info(f"Found {len(existing_docs)} documents to check")
    for doc in existing_docs:
        try:
            doc_metadata = DocMetadata.model_validate_json(json.dumps(doc.doc_metadata))
            if doc_metadata.repo not in repo_to_doc_list_map:
                repo_to_doc_list_map[doc_metadata.repo] = []
            repo_to_doc_list_map[doc_metadata.repo].append(doc)
        except Exception as e:
            logger.error(f"Failed to parse doc metadata: {e} for doc {doc.id}")
            continue
    logger.info(f"Found {len(repo_to_doc_list_map)} repositories with existing documents to check")
    # Process each repository individually
    for repo in repos:
        try:
            logger.info(f"Processing repository: {repo.id} (name: {repo.name})")
            repo_doc_list: list[DocumentRow] = repo_to_doc_list_map.get(
                repo.full_name, []
            )
            if not repo_doc_list:
                logger.warning(
                    f"No documents found for repository {repo.id} ({repo.name})"
                )
                continue

            current_external_group_ids = repo_doc_list[0].external_user_group_ids or []
            # Check if repository has any permission changes
            has_changes = _check_repository_for_changes(
                repo=repo,
                github_client=github_connector.github_client,
                current_external_group_ids=current_external_group_ids,
            )

            if has_changes:
                logger.info(
                    f"Repository {repo.id} ({repo.name}) has changes, updating documents"
                )

                # Get new external access permissions for this repository
                new_external_access = get_external_access_permission(
                    repo, github_connector.github_client
                )

                logger.info(
                    f"Found {len(repo_doc_list)} documents for repository {repo.full_name}"
                )

                # Yield updated external access for each document
                for doc in repo_doc_list:
                    if callback:
                        callback.progress(GITHUB_DOC_SYNC_LABEL, 1)

                    yield DocExternalAccess(
                        doc_id=doc.id,
                        external_access=new_external_access,
                    )
            else:
                logger.info(
                    f"Repository {repo.id} ({repo.name}) has no changes, skipping"
                )
        except Exception as e:
            logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")

    logger.info(f"GitHub document sync completed for CC pair ID: {cc_pair.id}")


def _check_repository_for_changes(
    repo: Repository,
    github_client: Github,
    current_external_group_ids: list[str],
) -> bool:
    """
    Check if repository has any permission changes (visibility or team updates).
    """
    logger.info(f"Checking repository {repo.id} ({repo.name}) for changes")

    # Check for repository visibility changes using the sample document data
    if _is_repo_visibility_changed_from_groups(
        repo=repo,
        current_external_group_ids=current_external_group_ids,
    ):
        logger.info(f"Repository {repo.id} ({repo.name}) has visibility changes")
        return True

    # Check for team membership changes if repository is private
    if get_repository_visibility(
        repo
    ) == GitHubVisibility.PRIVATE and _teams_updated_from_groups(
        repo=repo,
        github_client=github_client,
        current_external_group_ids=current_external_group_ids,
    ):
        logger.info(f"Repository {repo.id} ({repo.name}) has team changes")
        return True

    logger.info(f"Repository {repo.id} ({repo.name}) has no changes")
    return False


def _is_repo_visibility_changed_from_groups(
    repo: Repository,
    current_external_group_ids: list[str],
) -> bool:
    """
    Check if repository visibility has changed by analyzing existing external group IDs.

    Args:
        repo: GitHub repository object
        current_external_group_ids: List of external group IDs from existing document

    Returns:
        True if visibility has changed
    """
    current_repo_visibility = get_repository_visibility(repo)
    logger.info(f"Current repository visibility: {current_repo_visibility.value}")

    # Build expected group IDs for current visibility
    collaborators_group_id = build_ext_group_name_for_onyx(
        source=DocumentSource.GITHUB,
        ext_group_name=form_collaborators_group_id(repo.id),
    )

    org_group_id = None
    if repo.organization:
        org_group_id = build_ext_group_name_for_onyx(
            source=DocumentSource.GITHUB,
            ext_group_name=form_organization_group_id(repo.organization.id),
        )

    # Determine existing visibility from group IDs
    has_collaborators_group = collaborators_group_id in current_external_group_ids
    has_org_group = org_group_id and org_group_id in current_external_group_ids

    if has_collaborators_group:
        existing_repo_visibility = GitHubVisibility.PRIVATE
    elif has_org_group:
        existing_repo_visibility = GitHubVisibility.INTERNAL
    else:
        existing_repo_visibility = GitHubVisibility.PUBLIC

    logger.info(f"Inferred existing visibility: {existing_repo_visibility.value}")

    visibility_changed = existing_repo_visibility != current_repo_visibility
    if visibility_changed:
        logger.info(
            f"Visibility changed for repo {repo.id} ({repo.name}): "
            f"{existing_repo_visibility.value} -> {current_repo_visibility.value}"
        )

    return visibility_changed


def _teams_updated_from_groups(
    repo: Repository,
    github_client: Github,
    current_external_group_ids: list[str],
) -> bool:
    """
    Check if repository team memberships have changed using existing group IDs.
    """
    # Fetch current team slugs for the repository
    current_teams = fetch_repository_team_slugs(repo=repo, github_client=github_client)
    logger.info(
        f"Current teams for repository {repo.id} (name: {repo.name}): {current_teams}"
    )

    # Build group IDs to exclude from team comparison (non-team groups)
    collaborators_group_id = build_ext_group_name_for_onyx(
        source=DocumentSource.GITHUB,
        ext_group_name=form_collaborators_group_id(repo.id),
    )
    outside_collaborators_group_id = build_ext_group_name_for_onyx(
        source=DocumentSource.GITHUB,
        ext_group_name=form_outside_collaborators_group_id(repo.id),
    )
    non_team_group_ids = {collaborators_group_id, outside_collaborators_group_id}

    # Extract existing team IDs from current external group IDs
    existing_team_ids = set()
    for group_id in current_external_group_ids:
        # Skip all non-team groups, keep only team groups
        if group_id not in non_team_group_ids:
            existing_team_ids.add(group_id)

    # Note: existing_team_ids from DB are already prefixed (e.g., "github__team-slug")
    # but current_teams from API are raw team slugs, so we need to add the prefix
    current_team_ids = set()
    for team_slug in current_teams:
        team_group_id = build_ext_group_name_for_onyx(
            source=DocumentSource.GITHUB,
            ext_group_name=team_slug,
        )
        current_team_ids.add(team_group_id)

    logger.info(
        f"Existing team IDs: {existing_team_ids}, Current team IDs: {current_team_ids}"
    )

    # Compare actual team IDs to detect changes
    teams_changed = current_team_ids != existing_team_ids
    if teams_changed:
        logger.info(
            f"Team changes detected for repo {repo.id} (name: {repo.name}): "
            f"existing={existing_team_ids}, current={current_team_ids}"
        )

    return teams_changed
46 backend/ee/onyx/external_permissions/github/group_sync.py (new file)
@@ -0,0 +1,46 @@
from collections.abc import Generator

from github import Repository

from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.github.utils import get_external_user_group
from onyx.connectors.github.connector import GithubConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger

logger = setup_logger()


def github_group_sync(
    tenant_id: str,
    cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
    github_connector: GithubConnector = GithubConnector(
        **cc_pair.connector.connector_specific_config
    )
    github_connector.load_credentials(cc_pair.credential.credential_json)
    if not github_connector.github_client:
        raise ValueError("github_client is required")

    logger.info("Starting GitHub group sync...")
    repos: list[Repository.Repository] = []
    if github_connector.repositories:
        if "," in github_connector.repositories:
            # Multiple repositories specified
            repos = github_connector.get_github_repos(github_connector.github_client)
        else:
            # Single repository (backward compatibility)
            repos = [github_connector.get_github_repo(github_connector.github_client)]
    else:
        # All repositories
        repos = github_connector.get_all_repos(github_connector.github_client)

    for repo in repos:
        try:
            for external_group in get_external_user_group(
                repo, github_connector.github_client
            ):
                logger.info(f"External group: {external_group}")
                yield external_group
        except Exception as e:
            logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")
488 backend/ee/onyx/external_permissions/github/utils.py (new file)
@@ -0,0 +1,488 @@
from collections.abc import Callable
from enum import Enum
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar

from github import Github
from github import RateLimitExceededException
from github.GithubException import GithubException
from github.NamedUser import NamedUser
from github.Organization import Organization
from github.PaginatedList import PaginatedList
from github.Repository import Repository
from github.Team import Team
from pydantic import BaseModel

from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
from onyx.utils.logger import setup_logger

logger = setup_logger()


class GitHubVisibility(Enum):
    """GitHub repository visibility options."""

    PUBLIC = "public"
    PRIVATE = "private"
    INTERNAL = "internal"


MAX_RETRY_COUNT = 3

T = TypeVar("T")

# Higher-order function to wrap GitHub operations with retry and exception handling


def _run_with_retry(
    operation: Callable[[], T],
    description: str,
    github_client: Github,
    retry_count: int = 0,
) -> Optional[T]:
    """Execute a GitHub operation with retry on rate limit and exception handling."""
    logger.debug(f"Starting operation '{description}', attempt {retry_count + 1}")
    try:
        result = operation()
        logger.debug(f"Operation '{description}' completed successfully")
        return result
    except RateLimitExceededException:
        if retry_count < MAX_RETRY_COUNT:
            sleep_after_rate_limit_exception(github_client)
            logger.warning(
                f"Rate limit exceeded while {description}. Retrying... "
                f"(attempt {retry_count + 1}/{MAX_RETRY_COUNT})"
            )
            return _run_with_retry(
                operation, description, github_client, retry_count + 1
            )
        else:
            error_msg = f"Max retries exceeded for {description}"
            logger.exception(error_msg)
            raise RuntimeError(error_msg)
    except GithubException as e:
        logger.warning(f"GitHub API error during {description}: {e}")
        return None
    except Exception as e:
        logger.exception(f"Unexpected error during {description}: {e}")
        return None


class UserInfo(BaseModel):
    """Represents a GitHub user with their basic information."""

    login: str
    name: Optional[str] = None
    email: Optional[str] = None


class TeamInfo(BaseModel):
    """Represents a GitHub team with its members."""

    name: str
    slug: str
    members: List[UserInfo]


def _fetch_organization_members(
    github_client: Github, org_name: str, retry_count: int = 0
) -> List[UserInfo]:
    """Fetch all organization members including owners and regular members."""
    org_members: List[UserInfo] = []
    logger.info(f"Fetching organization members for {org_name}")

    org = _run_with_retry(
        lambda: github_client.get_organization(org_name),
        f"get organization {org_name}",
        github_client,
    )
    if not org:
        logger.error(f"Failed to fetch organization {org_name}")
        raise RuntimeError(f"Failed to fetch organization {org_name}")

    member_objs: PaginatedList[NamedUser] | list[NamedUser] = (
        _run_with_retry(
            lambda: org.get_members(filter_="all"),
            f"get members for organization {org_name}",
            github_client,
        )
        or []
    )

    for member in member_objs:
        user_info = UserInfo(login=member.login, name=member.name, email=member.email)
        org_members.append(user_info)

    logger.info(f"Fetched {len(org_members)} members for organization {org_name}")
    return org_members


def _fetch_repository_teams_detailed(
    repo: Repository, github_client: Github, retry_count: int = 0
) -> List[TeamInfo]:
    """Fetch teams with access to the repository and their members."""
    teams_data: List[TeamInfo] = []
    logger.info(f"Fetching teams for repository {repo.full_name}")

    team_objs: PaginatedList[Team] | list[Team] = (
        _run_with_retry(
            lambda: repo.get_teams(),
            f"get teams for repository {repo.full_name}",
            github_client,
        )
        or []
    )

    for team in team_objs:
        logger.info(
            f"Processing team {team.name} (slug: {team.slug}) for repository {repo.full_name}"
        )

        members: PaginatedList[NamedUser] | list[NamedUser] = (
            _run_with_retry(
                lambda: team.get_members(),
                f"get members for team {team.name}",
                github_client,
            )
            or []
        )

        team_members = []
        for m in members:
            user_info = UserInfo(login=m.login, name=m.name, email=m.email)
            team_members.append(user_info)

        team_info = TeamInfo(name=team.name, slug=team.slug, members=team_members)
        teams_data.append(team_info)
        logger.info(f"Team {team.name} has {len(team_members)} members")

    logger.info(f"Fetched {len(teams_data)} teams for repository {repo.full_name}")
    return teams_data


def fetch_repository_team_slugs(
    repo: Repository, github_client: Github, retry_count: int = 0
) -> List[str]:
    """Fetch team slugs with access to the repository."""
    logger.info(f"Fetching team slugs for repository {repo.full_name}")
    teams_data: List[str] = []

    team_objs: PaginatedList[Team] | list[Team] = (
        _run_with_retry(
            lambda: repo.get_teams(),
            f"get teams for repository {repo.full_name}",
            github_client,
        )
        or []
    )

    for team in team_objs:
        teams_data.append(team.slug)

    logger.info(f"Fetched {len(teams_data)} team slugs for repository {repo.full_name}")
    return teams_data


def _get_collaborators_and_outside_collaborators(
    github_client: Github,
    repo: Repository,
) -> Tuple[List[UserInfo], List[UserInfo]]:
    """Fetch and categorize collaborators into regular and outside collaborators."""
    collaborators: List[UserInfo] = []
    outside_collaborators: List[UserInfo] = []
    logger.info(f"Fetching collaborators for repository {repo.full_name}")

    repo_collaborators: PaginatedList[NamedUser] | list[NamedUser] = (
        _run_with_retry(
            lambda: repo.get_collaborators(),
            f"get collaborators for repository {repo.full_name}",
            github_client,
        )
        or []
    )

    for collaborator in repo_collaborators:
        is_outside = False

        # Check if collaborator is outside the organization
        if repo.organization:
            org: Organization | None = _run_with_retry(
                lambda: github_client.get_organization(repo.organization.login),
                f"get organization {repo.organization.login}",
                github_client,
            )

            if org is not None:
                org_obj = org
                membership = _run_with_retry(
                    lambda: org_obj.has_in_members(collaborator),
                    f"check membership for {collaborator.login} in org {org_obj.login}",
                    github_client,
                )
                is_outside = membership is not None and not membership

        info = UserInfo(
            login=collaborator.login, name=collaborator.name, email=collaborator.email
        )
        if repo.organization and is_outside:
            outside_collaborators.append(info)
        else:
            collaborators.append(info)

    logger.info(
        f"Categorized {len(collaborators)} regular and {len(outside_collaborators)} outside collaborators for {repo.full_name}"
    )
    return collaborators, outside_collaborators


def form_collaborators_group_id(repository_id: int) -> str:
    """Generate group ID for repository collaborators."""
    if not repository_id:
        logger.exception("Repository ID is required to generate collaborators group ID")
        raise ValueError("Repository ID must be set to generate group ID.")
    group_id = f"{repository_id}_collaborators"
    return group_id


def form_organization_group_id(organization_id: int) -> str:
    """Generate group ID for organization using organization ID."""
    if not organization_id:
        logger.exception(
            "Organization ID is required to generate organization group ID"
        )
        raise ValueError("Organization ID must be set to generate group ID.")
    group_id = f"{organization_id}_organization"
    return group_id


def form_outside_collaborators_group_id(repository_id: int) -> str:
    """Generate group ID for outside collaborators."""
    if not repository_id:
        logger.exception(
            "Repository ID is required to generate outside collaborators group ID"
        )
        raise ValueError("Repository ID must be set to generate group ID.")
    group_id = f"{repository_id}_outside_collaborators"
    return group_id


def get_repository_visibility(repo: Repository) -> GitHubVisibility:
    """
    Get the visibility of a repository.
    Returns GitHubVisibility enum member.
    """
    if hasattr(repo, "visibility"):
        visibility = repo.visibility
        logger.info(
            f"Repository {repo.full_name} visibility from attribute: {visibility}"
        )
        try:
            return GitHubVisibility(visibility)
        except ValueError:
            logger.warning(
                f"Unknown visibility '{visibility}' for repo {repo.full_name}, defaulting to private"
            )
            return GitHubVisibility.PRIVATE

    logger.info(f"Repository {repo.full_name} is private")
    return GitHubVisibility.PRIVATE


def get_external_access_permission(
    repo: Repository, github_client: Github, add_prefix: bool = False
) -> ExternalAccess:
    """
    Get the external access permission for a repository.
    Uses group-based permissions for efficiency and scalability.

    add_prefix: When this method is called during the initial permission sync via the connector,
        the group ID isn't prefixed with the source while inserting the document record.
        So in that case, set add_prefix to True, allowing the method itself to handle
        prefixing. However, when the same method is invoked from doc_sync, our system
        already adds the prefix to the group ID while processing the ExternalAccess object.
    """
    # We maintain collaborators, and outside collaborators as two separate groups
    # instead of adding individual user emails to ExternalAccess.external_user_emails for two reasons:
    # 1. Changes in repo collaborators (additions/removals) would require updating all documents.
    # 2. Repo permissions can change without updating the repo's updated_at timestamp,
    #    forcing full permission syncs for all documents every time, which is inefficient.

    repo_visibility = get_repository_visibility(repo)
    logger.info(
        f"Generating ExternalAccess for {repo.full_name}: visibility={repo_visibility.value}"
    )

    if repo_visibility == GitHubVisibility.PUBLIC:
        logger.info(
            f"Repository {repo.full_name} is public - allowing access to all users"
        )
        return ExternalAccess(
            external_user_emails=set(),
            external_user_group_ids=set(),
            is_public=True,
        )
    elif repo_visibility == GitHubVisibility.PRIVATE:
        logger.info(
            f"Repository {repo.full_name} is private - setting up restricted access"
        )

        collaborators_group_id = form_collaborators_group_id(repo.id)
        outside_collaborators_group_id = form_outside_collaborators_group_id(repo.id)
        if add_prefix:
            collaborators_group_id = build_ext_group_name_for_onyx(
                source=DocumentSource.GITHUB,
                ext_group_name=collaborators_group_id,
            )
            outside_collaborators_group_id = build_ext_group_name_for_onyx(
                source=DocumentSource.GITHUB,
                ext_group_name=outside_collaborators_group_id,
            )
        group_ids = {collaborators_group_id, outside_collaborators_group_id}

        team_slugs = fetch_repository_team_slugs(repo, github_client)
        if add_prefix:
            team_slugs = [
                build_ext_group_name_for_onyx(
                    source=DocumentSource.GITHUB,
                    ext_group_name=slug,
                )
                for slug in team_slugs
            ]
        group_ids.update(team_slugs)

        logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
        return ExternalAccess(
            external_user_emails=set(),
            external_user_group_ids=group_ids,
            is_public=False,
        )
    else:
        # Internal repositories - accessible to organization members
        logger.info(
            f"Repository {repo.full_name} is internal - accessible to org members"
        )
        org_group_id = form_organization_group_id(repo.organization.id)
        if add_prefix:
            org_group_id = build_ext_group_name_for_onyx(
                source=DocumentSource.GITHUB,
                ext_group_name=org_group_id,
            )
        group_ids = {org_group_id}
        logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
        return ExternalAccess(
            external_user_emails=set(),
            external_user_group_ids=group_ids,
            is_public=False,
        )


def get_external_user_group(
    repo: Repository, github_client: Github
) -> list[ExternalUserGroup]:
    """
    Get the external user group for a repository.
    Creates ExternalUserGroup objects with actual user emails for each permission group.
    """
    repo_visibility = get_repository_visibility(repo)
    logger.info(
        f"Generating ExternalUserGroups for {repo.full_name}: visibility={repo_visibility.value}"
    )

    if repo_visibility == GitHubVisibility.PRIVATE:
        logger.info(f"Processing private repository {repo.full_name}")

        collaborators, outside_collaborators = (
            _get_collaborators_and_outside_collaborators(github_client, repo)
        )
        teams = _fetch_repository_teams_detailed(repo, github_client)
        external_user_groups = []

        user_emails = set()
        for collab in collaborators:
            if collab.email:
                user_emails.add(collab.email)
            else:
                logger.error(f"Collaborator {collab.login} has no email")

        if user_emails:
            collaborators_group = ExternalUserGroup(
                id=form_collaborators_group_id(repo.id),
                user_emails=list(user_emails),
            )
            external_user_groups.append(collaborators_group)
            logger.info(f"Created collaborators group with {len(user_emails)} emails")

        # Create group for outside collaborators
        user_emails = set()
        for collab in outside_collaborators:
            if collab.email:
                user_emails.add(collab.email)
            else:
                logger.error(f"Outside collaborator {collab.login} has no email")

        if user_emails:
            outside_collaborators_group = ExternalUserGroup(
                id=form_outside_collaborators_group_id(repo.id),
                user_emails=list(user_emails),
            )
            external_user_groups.append(outside_collaborators_group)
            logger.info(
                f"Created outside collaborators group with {len(user_emails)} emails"
            )

        # Create groups for teams
        for team in teams:
            user_emails = set()
            for member in team.members:
                if member.email:
                    user_emails.add(member.email)
                else:
                    logger.error(f"Team member {member.login} has no email")

            if user_emails:
                team_group = ExternalUserGroup(
                    id=team.slug,
                    user_emails=list(user_emails),
                )
                external_user_groups.append(team_group)
                logger.info(
                    f"Created team group {team.name} with {len(user_emails)} emails"
                )

        logger.info(
            f"Created {len(external_user_groups)} ExternalUserGroups for private repository {repo.full_name}"
        )
        return external_user_groups

    if repo_visibility == GitHubVisibility.INTERNAL:
        logger.info(f"Processing internal repository {repo.full_name}")

        org_group_id = form_organization_group_id(repo.organization.id)
        org_members = _fetch_organization_members(
            github_client, repo.organization.login
        )

        user_emails = set()
        for member in org_members:
            if member.email:
                user_emails.add(member.email)
            else:
                logger.error(f"Org member {member.login} has no email")

        org_group = ExternalUserGroup(
            id=org_group_id,
            user_emails=list(user_emails),
        )
        logger.info(
            f"Created organization group with {len(user_emails)} emails for internal repository {repo.full_name}"
        )
        return [org_group]

    logger.info(f"Repository {repo.full_name} is public - no user groups needed")
    return []
@@ -3,6 +3,7 @@ from datetime import datetime
from datetime import timezone

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
@@ -35,6 +36,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
    callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
    """
@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.google_drive.permission_retrieval import (
    get_permissions_by_ids,
)
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.connector import GoogleDriveConnector
@@ -169,6 +170,7 @@ def get_external_access_for_raw_gdrive_file(
def gdrive_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
    callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
    """
@@ -1,6 +1,7 @@
from collections.abc import Generator

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -17,6 +18,7 @@ JIRA_DOC_SYNC_TAG = "jira_doc_sync"
def jira_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
    callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
    jira_connector = JiraConnector(
@@ -26,7 +28,7 @@ def jira_doc_sync(

    yield from generic_doc_sync(
        cc_pair=cc_pair,
        fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
        fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
        callback=callback,
        doc_source=DocumentSource.JIRA,
        slim_connector=jira_connector,
@@ -5,6 +5,8 @@ from typing import Protocol
from typing import TYPE_CHECKING

from onyx.context.search.models import InferenceChunk
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder

# Avoid circular imports
if TYPE_CHECKING:
@@ -15,14 +17,34 @@ if TYPE_CHECKING:


class FetchAllDocumentsFunction(Protocol):
    """Protocol for a function that fetches all document IDs for a connector credential pair."""
    """Protocol for a function that fetches documents for a connector credential pair.

    def __call__(self) -> list[str]:
    This protocol defines the interface for functions that retrieve documents
    from the database, typically used in permission synchronization workflows.
    """

    def __call__(
        self,
        sort_order: SortOrder | None,
    ) -> list[DocumentRow]:
        """
        Returns a list of document IDs for a connector credential pair.
        Fetches documents for a connector credential pair.
        """
        ...

    This is typically used to determine which documents should no longer be
    accessible during the document sync process.

class FetchAllDocumentsIdsFunction(Protocol):
    """Protocol for a function that fetches document IDs for a connector credential pair.

    This protocol defines the interface for functions that retrieve document IDs
    from the database, typically used in permission synchronization workflows.
    """

    def __call__(
        self,
    ) -> list[str]:
        """
        Fetches document IDs for a connector credential pair.
        """
        ...

@@ -32,6 +54,7 @@ DocSyncFuncType = Callable[
    [
        "ConnectorCredentialPair",
        FetchAllDocumentsFunction,
        FetchAllDocumentsIdsFunction,
        Optional["IndexingHeartbeatInterface"],
    ],
    Generator["DocExternalAccess", None, None],
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
|
||||
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
|
||||
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
|
||||
from onyx.access.models import DocExternalAccess
|
||||
from onyx.access.models import ExternalAccess
|
||||
@@ -130,6 +131,7 @@ def _get_slack_document_access(
|
||||
def slack_doc_sync(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
|
||||
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
|
||||
callback: IndexingHeartbeatInterface | None,
|
||||
) -> Generator[DocExternalAccess, None, None]:
|
||||
"""
|
||||
|
||||
@@ -7,12 +7,16 @@ from pydantic import BaseModel
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
from ee.onyx.external_permissions.github.doc_sync import github_doc_sync
from ee.onyx.external_permissions.github.group_sync import github_group_sync
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
@@ -20,6 +24,7 @@ from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync
from ee.onyx.external_permissions.perm_sync_types import CensoringFuncType
from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
from ee.onyx.external_permissions.salesforce.postprocessing import (
    censor_salesforce_chunks,
@@ -63,6 +68,7 @@ class SyncConfig(BaseModel):
def mock_doc_sync(
    cc_pair: "ConnectorCredentialPair",
    fetch_all_docs_fn: FetchAllDocumentsFunction,
    fetch_all_docs_ids_fn: FetchAllDocumentsIdsFunction,
    callback: Optional["IndexingHeartbeatInterface"],
) -> Generator["DocExternalAccess", None, None]:
    """Mock doc sync function for testing - returns empty list since permissions are fetched during indexing"""
@@ -117,6 +123,18 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
            initial_index_should_sync=False,
        ),
    ),
    DocumentSource.GITHUB: SyncConfig(
        doc_sync_config=DocSyncConfig(
            doc_sync_frequency=GITHUB_PERMISSION_DOC_SYNC_FREQUENCY,
            doc_sync_func=github_doc_sync,
            initial_index_should_sync=True,
        ),
        group_sync_config=GroupSyncConfig(
            group_sync_frequency=GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY,
            group_sync_func=github_group_sync,
            group_sync_is_cc_pair_agnostic=False,
        ),
    ),
    DocumentSource.SALESFORCE: SyncConfig(
        censoring_config=CensoringConfig(
            chunk_censoring_func=censor_salesforce_chunks,
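To show how an entry such as the new GITHUB one is consumed, here is a hedged dispatch sketch; it assumes the surrounding module's _SOURCE_TO_SYNC_CONFIG is in scope, and run_doc_sync_sketch is an invented name (the real call site is the celery permission-sync task shown earlier in this diff):

# Hypothetical dispatch sketch, not the actual call site.
def run_doc_sync_sketch(source, cc_pair, fetch_docs_fn, fetch_doc_ids_fn, callback=None):
    sync_config = _SOURCE_TO_SYNC_CONFIG.get(source)
    if sync_config is None or sync_config.doc_sync_config is None:
        return  # this source has no document-level permission sync configured
    doc_sync_func = sync_config.doc_sync_config.doc_sync_func
    for doc_external_access in doc_sync_func(cc_pair, fetch_docs_fn, fetch_doc_ids_fn, callback):
        ...  # persist the updated ExternalAccess for each document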
@@ -1,6 +1,7 @@
from collections.abc import Generator

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -18,6 +19,7 @@ TEAMS_DOC_SYNC_LABEL = "teams_doc_sync"
def teams_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
    callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
    teams_connector = TeamsConnector(
@@ -27,7 +29,7 @@ def teams_doc_sync(

    yield from generic_doc_sync(
        cc_pair=cc_pair,
        fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
        fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
        callback=callback,
        doc_source=DocumentSource.TEAMS,
        slim_connector=teams_connector,
@@ -1,6 +1,6 @@
from collections.abc import Generator

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
@@ -14,7 +14,7 @@ logger = setup_logger()

def generic_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
    callback: IndexingHeartbeatInterface | None,
    doc_source: DocumentSource,
    slim_connector: SlimConnector,
@@ -62,9 +62,9 @@ def generic_doc_sync(
    )

    logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id=}")
    existing_doc_ids = set(fetch_all_existing_docs_fn())
    existing_doc_ids: list[str] = fetch_all_existing_docs_ids_fn()

    missing_doc_ids = existing_doc_ids - newly_fetched_doc_ids
    missing_doc_ids = set(existing_doc_ids) - newly_fetched_doc_ids

    if not missing_doc_ids:
        return
@@ -206,7 +206,7 @@ def _handle_standard_answers(

        restate_question_blocks = get_restate_blocks(
            msg=query_msg.message,
            is_bot_msg=message_info.is_bot_msg,
            is_slash_command=message_info.is_slash_command,
        )

        answer_blocks = build_standard_answer_blocks(
@@ -67,7 +67,7 @@ def generate_chat_messages_report(
    file_id = file_store.save_file(
        content=temp_file,
        display_name=file_name,
        file_origin=FileOrigin.OTHER,
        file_origin=FileOrigin.GENERATED_REPORT,
        file_type="text/csv",
    )

@@ -99,7 +99,7 @@ def generate_user_report(
    file_id = file_store.save_file(
        content=temp_file,
        display_name=file_name,
        file_origin=FileOrigin.OTHER,
        file_origin=FileOrigin.GENERATED_REPORT,
        file_type="text/csv",
    )
@@ -231,10 +231,7 @@ class DynamicTenantScheduler(PersistentScheduler):
        True if equivalent, False if not."""
        current_tasks = set(name for name, _ in schedule1)
        new_tasks = set(schedule2.keys())
        if current_tasks != new_tasks:
            return False

        return True
        return current_tasks == new_tasks


@beat_init.connect
@@ -32,7 +32,6 @@ from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
@@ -161,7 +160,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    RedisUserGroup.reset_all(r)
    RedisConnectorDelete.reset_all(r)
    RedisConnectorPrune.reset_all(r)
    RedisConnectorIndex.reset_all(r)
    RedisConnectorStop.reset_all(r)
    RedisConnectorPermissionSync.reset_all(r)
    RedisConnectorExternalGroupSync.reset_all(r)
@@ -1,3 +1,5 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from pathlib import Path
|
||||
@@ -8,10 +10,12 @@ import httpx
|
||||
|
||||
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
|
||||
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
|
||||
from onyx.connectors.connector_runner import batched_doc_ids
|
||||
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.interfaces import BaseConnector
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
@@ -22,12 +26,14 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
|
||||
|
||||
|
||||
def document_batch_to_ids(
|
||||
doc_batch: list[Document],
|
||||
) -> set[str]:
|
||||
return {doc.id for doc in doc_batch}
|
||||
doc_batch: Iterator[list[Document]],
|
||||
) -> Generator[set[str], None, None]:
|
||||
for doc_list in doc_batch:
|
||||
yield {doc.id for doc in doc_list}
|
||||
|
||||
|
||||
def extract_ids_from_runnable_connector(
|
||||
@@ -46,33 +52,50 @@ def extract_ids_from_runnable_connector(
|
||||
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
|
||||
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
|
||||
|
||||
doc_batch_generator = None
|
||||
doc_batch_id_generator = None
|
||||
|
||||
if isinstance(runnable_connector, LoadConnector):
|
||||
doc_batch_generator = runnable_connector.load_from_state()
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.load_from_state()
|
||||
)
|
||||
elif isinstance(runnable_connector, PollConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
|
||||
doc_batch_id_generator = document_batch_to_ids(
|
||||
runnable_connector.poll_source(start=start, end=end)
|
||||
)
|
||||
elif isinstance(runnable_connector, CheckpointedConnector):
|
||||
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
end = datetime.now(timezone.utc).timestamp()
|
||||
checkpoint = runnable_connector.build_dummy_checkpoint()
|
||||
checkpoint_generator = runnable_connector.load_from_checkpoint(
|
||||
start=start, end=end, checkpoint=checkpoint
|
||||
)
|
||||
doc_batch_id_generator = batched_doc_ids(
|
||||
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
|
||||
)
|
||||
else:
|
||||
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
|
||||
|
||||
doc_batch_processing_func = document_batch_to_ids
|
||||
# this function is called per batch for rate limiting
|
||||
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
|
||||
return doc_batch_ids
|
||||
|
||||
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
|
||||
doc_batch_processing_func = rate_limit_builder(
|
||||
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
|
||||
)(document_batch_to_ids)
|
||||
for doc_batch in doc_batch_generator:
|
||||
)(lambda x: x)
|
||||
for doc_batch_ids in doc_batch_id_generator:
|
||||
if callback:
|
||||
if callback.should_stop():
|
||||
raise RuntimeError(
|
||||
"extract_ids_from_runnable_connector: Stop signal detected"
|
||||
)
|
||||
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
|
||||
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))
|
||||
|
||||
if callback:
|
||||
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
|
||||
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
|
||||
|
||||
return all_connector_doc_ids
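
The refactor above applies the rate limiter per batch of IDs rather than per raw document batch. The sketch below is a self-contained stand-in for that idea, using a simple minimum-interval throttle; rate_limit_builder in the source may behave differently, so treat the throttle itself as an assumption:

import time
from collections.abc import Callable, Iterable


def min_interval_limiter(max_calls: int, period: float) -> Callable:
    """Illustrative throttle: space calls so at most max_calls happen per period."""
    min_interval = period / max_calls
    last_call = 0.0

    def decorator(fn: Callable) -> Callable:
        def wrapped(*args, **kwargs):
            nonlocal last_call
            wait = min_interval - (time.monotonic() - last_call)
            if wait > 0:
                time.sleep(wait)
            last_call = time.monotonic()
            return fn(*args, **kwargs)

        return wrapped

    return decorator


def collect_ids(id_batches: Iterable[set[str]], max_batches_per_minute: int | None) -> set[str]:
    def identity(ids: set[str]) -> set[str]:
        return ids

    process = identity
    if max_batches_per_minute:
        process = min_interval_limiter(max_batches_per_minute, 60.0)(identity)

    all_ids: set[str] = set()
    for batch in id_batches:
        all_ids.update(process(batch))
    return all_ids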
|
||||
|
||||
|
||||
@@ -193,12 +193,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
|
||||
task_logger.info(
|
||||
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
|
||||
)
|
||||
search_settings_list = get_all_search_settings(db_session)
|
||||
for search_settings in search_settings_list:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
search_settings.id
|
||||
)
|
||||
redis_connector_index.reset()
|
||||
|
||||
redis_connector.prune.reset()
|
||||
redis_connector.permissions.reset()
|
||||
redis_connector.external_group_sync.reset()
|
||||
|
||||
@@ -2,7 +2,6 @@ import multiprocessing
|
||||
import os
|
||||
import time
|
||||
import traceback
|
||||
from http import HTTPStatus
|
||||
from time import sleep
|
||||
|
||||
import sentry_sdk
|
||||
@@ -22,7 +21,7 @@ from onyx.background.celery.tasks.models import SimpleJobResult
|
||||
from onyx.background.indexing.job_client import SimpleJob
|
||||
from onyx.background.indexing.job_client import SimpleJobClient
|
||||
from onyx.background.indexing.job_client import SimpleJobException
|
||||
from onyx.background.indexing.run_docfetching import run_indexing_entrypoint
|
||||
from onyx.background.indexing.run_docfetching import run_docfetching_entrypoint
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
@@ -34,7 +33,6 @@ from onyx.db.index_attempt import mark_attempt_canceled
|
||||
from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import SENTRY_DSN
|
||||
@@ -156,7 +154,6 @@ def _docfetching_task(
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector.new_index(search_settings_id)
|
||||
|
||||
# TODO: remove all fences, cause all signals to be set in postgres
|
||||
if redis_connector.delete.fenced:
|
||||
@@ -214,7 +211,7 @@ def _docfetching_task(
|
||||
)
|
||||
|
||||
# This is where the heavy/real work happens
|
||||
run_indexing_entrypoint(
|
||||
run_docfetching_entrypoint(
|
||||
app,
|
||||
index_attempt_id,
|
||||
tenant_id,
|
||||
@@ -261,7 +258,7 @@ def _docfetching_task(
|
||||
def process_job_result(
|
||||
job: SimpleJob,
|
||||
connector_source: str | None,
|
||||
redis_connector_index: RedisConnectorIndex,
|
||||
index_attempt_id: int,
|
||||
log_builder: ConnectorIndexingLogBuilder,
|
||||
) -> SimpleJobResult:
|
||||
result = SimpleJobResult()
|
||||
@@ -278,13 +275,11 @@ def process_job_result(
|
||||
|
||||
# In EKS, there is an edge case where successful tasks return exit
|
||||
# code 1 in the cloud due to the set_spawn_method not sticking.
|
||||
# We've since worked around this, but the following is a safe way to
|
||||
# work around this issue. Basically, we ignore the job error state
|
||||
# if the completion signal is OK.
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int:
|
||||
status_enum = HTTPStatus(status_int)
|
||||
if status_enum == HTTPStatus.OK:
|
||||
# Workaround: check that the total number of batches is set, since this only
|
||||
# happens when docfetching completed successfully
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if index_attempt and index_attempt.total_batches is not None:
|
||||
ignore_exitcode = True
|
||||
|
||||
if ignore_exitcode:
|
||||
@@ -458,9 +453,6 @@ def docfetching_proxy_task(
|
||||
)
|
||||
)
|
||||
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# Track the last time memory info was emitted
|
||||
last_memory_emit_time = 0.0
|
||||
|
||||
@@ -487,7 +479,7 @@ def docfetching_proxy_task(
|
||||
if job.done():
|
||||
try:
|
||||
result = process_job_result(
|
||||
job, result.connector_source, redis_connector_index, log_builder
|
||||
job, result.connector_source, index_attempt_id, log_builder
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
@@ -552,15 +544,20 @@ def docfetching_proxy_task(
|
||||
# print with exception
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
failure_reason = (
|
||||
f"Spawned task exceptioned: exit_code={result.exit_code}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
ctx.index_attempt_id,
|
||||
db_session,
|
||||
failure_reason=failure_reason,
|
||||
full_exception_trace=result.exception_str,
|
||||
)
|
||||
attempt = get_index_attempt(db_session, ctx.index_attempt_id)
|
||||
|
||||
# only mark failures if not already terminal,
|
||||
# otherwise we're overwriting potential real stack traces
|
||||
if attempt and not attempt.status.is_terminal():
|
||||
failure_reason = (
|
||||
f"Spawned task exceptioned: exit_code={result.exit_code}"
|
||||
)
|
||||
mark_attempt_failed(
|
||||
ctx.index_attempt_id,
|
||||
db_session,
|
||||
failure_reason=failure_reason,
|
||||
full_exception_trace=result.exception_str,
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(
|
||||
log_builder.build(
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
@@ -16,6 +15,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_redis import celery_find_task
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
|
||||
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
|
||||
@@ -66,6 +67,7 @@ from onyx.db.index_attempt import mark_attempt_failed
|
||||
from onyx.db.index_attempt import mark_attempt_partially_succeeded
|
||||
from onyx.db.index_attempt import mark_attempt_succeeded
|
||||
from onyx.db.indexing_coordination import CoordinationStatus
|
||||
from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS
|
||||
from onyx.db.indexing_coordination import IndexingCoordination
|
||||
from onyx.db.models import IndexAttempt
|
||||
from onyx.db.search_settings import get_active_search_settings_list
|
||||
@@ -102,6 +104,7 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
USER_FILE_INDEXING_LIMIT = 100
|
||||
DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER = 4
|
||||
|
||||
|
||||
def _get_fence_validation_block_expiration() -> int:
|
||||
@@ -257,7 +260,7 @@ class ConnectorIndexingLogBuilder:
|
||||
|
||||
|
||||
def monitor_indexing_attempt_progress(
|
||||
attempt: IndexAttempt, tenant_id: str, db_session: Session
|
||||
attempt: IndexAttempt, tenant_id: str, db_session: Session, task: Task
|
||||
) -> None:
|
||||
"""
|
||||
TODO: rewrite this docstring
|
||||
@@ -316,7 +319,9 @@ def monitor_indexing_attempt_progress(
|
||||
|
||||
# Check task completion using Celery
|
||||
try:
|
||||
check_indexing_completion(attempt.id, coordination_status, storage, tenant_id)
|
||||
check_indexing_completion(
|
||||
attempt.id, coordination_status, storage, tenant_id, task
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to monitor document processing completion: "
|
||||
@@ -350,6 +355,7 @@ def check_indexing_completion(
|
||||
coordination_status: CoordinationStatus,
|
||||
storage: DocumentBatchStorage,
|
||||
tenant_id: str,
|
||||
task: Task,
|
||||
) -> None:
|
||||
|
||||
logger.info(
|
||||
@@ -376,20 +382,78 @@ def check_indexing_completion(
|
||||
|
||||
# Update progress tracking and check for stalls
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Update progress tracking
|
||||
stalled_timeout_hours = INDEXING_PROGRESS_TIMEOUT_HOURS
|
||||
# Index attempts that are waiting between docfetching and
|
||||
# docprocessing get a generous stalling timeout
|
||||
if batches_total is not None and batches_processed == 0:
|
||||
stalled_timeout_hours = (
|
||||
stalled_timeout_hours * DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER
|
||||
)
|
||||
|
||||
timed_out = not IndexingCoordination.update_progress_tracking(
|
||||
db_session, index_attempt_id, batches_processed
|
||||
db_session,
|
||||
index_attempt_id,
|
||||
batches_processed,
|
||||
timeout_hours=stalled_timeout_hours,
|
||||
)
|
||||
|
||||
# Check for stalls (3-6 hour timeout)
|
||||
if timed_out:
|
||||
logger.error(
|
||||
f"Indexing attempt {index_attempt_id} has been indexing for 3-6 hours without progress. "
|
||||
f"Marking it as failed."
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt_id, db_session, failure_reason="Stalled indexing"
|
||||
)
|
||||
# Check for stalls (3-6 hour timeout). Only applies to in-progress attempts.
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if attempt and timed_out:
|
||||
if attempt.status == IndexingStatus.IN_PROGRESS:
|
||||
logger.error(
|
||||
f"Indexing attempt {index_attempt_id} has been indexing for "
|
||||
f"{stalled_timeout_hours//2}-{stalled_timeout_hours} hours without progress. "
|
||||
f"Marking it as failed."
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt_id, db_session, failure_reason="Stalled indexing"
|
||||
)
|
||||
elif (
|
||||
attempt.status == IndexingStatus.NOT_STARTED and attempt.celery_task_id
|
||||
):
|
||||
# Check if the task exists in the celery queue
|
||||
# This handles the case where Redis dies after task creation but before task execution
|
||||
redis_celery = task.app.broker_connection().channel().client # type: ignore
|
||||
task_exists = celery_find_task(
|
||||
attempt.celery_task_id,
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
redis_celery,
|
||||
)
|
||||
unacked_task_ids = celery_get_unacked_task_ids(
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, redis_celery
|
||||
)
|
||||
|
||||
if not task_exists and attempt.celery_task_id not in unacked_task_ids:
|
||||
# there is a race condition where the docfetching task has been taken off
|
||||
# the queues (i.e. started) but the indexing attempt still has a status of
|
||||
# Not Started because the switch to in progress takes like 0.1 seconds.
|
||||
# sleep a bit and confirm that the attempt is still not in progress.
|
||||
time.sleep(1)
|
||||
attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if attempt and attempt.status == IndexingStatus.NOT_STARTED:
|
||||
logger.error(
|
||||
f"Task {attempt.celery_task_id} attached to indexing attempt "
|
||||
f"{index_attempt_id} does not exist in the queue. "
|
||||
f"Marking indexing attempt as failed."
|
||||
)
|
||||
mark_attempt_failed(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
failure_reason="Task not in queue",
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"Indexing attempt {index_attempt_id} is {attempt.status}. 3-6 hours without heartbeat "
|
||||
"but task is in the queue. Likely underprovisioned docfetching worker."
|
||||
)
|
||||
# Update last progress time so we won't time out again for another 3 hours
|
||||
IndexingCoordination.update_progress_tracking(
|
||||
db_session,
|
||||
index_attempt_id,
|
||||
batches_processed,
|
||||
force_update_progress=True,
|
||||
)
|
||||
|
||||
# check again on the next check_for_indexing task
|
||||
# TODO: on the cloud this is currently 25 minutes at most, which
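
The stall handling above mixes several concerns; a condensed, assumption-heavy restatement of the decision follows (the status strings and the queue check are stand-ins for the real enums and Redis/Celery lookups):

def stall_timeout_hours(base_hours: int, batches_total: int | None, batches_processed: int) -> int:
    # attempts sitting between docfetching and docprocessing get a generous timeout
    if batches_total is not None and batches_processed == 0:
        return base_hours * 4  # DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER
    return base_hours


def stalled_action(status: str, timed_out: bool, task_in_queue: bool) -> str:
    if not timed_out:
        return "keep waiting"
    if status == "IN_PROGRESS":
        return "mark attempt failed: stalled indexing"
    if status == "NOT_STARTED" and not task_in_queue:
        return "mark attempt failed: task not in queue"
    return "reset the progress timer and re-check on the next pass"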
|
||||
@@ -449,15 +513,6 @@ def check_indexing_completion(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# TODO: make it so we don't need this (might already be true)
|
||||
redis_connector = RedisConnector(
|
||||
tenant_id, attempt.connector_credential_pair_id
|
||||
)
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
attempt.search_settings_id
|
||||
)
|
||||
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
|
||||
|
||||
# Clean up FileStore storage (still needed for document batches during transition)
|
||||
try:
|
||||
logger.info(f"Cleaning up storage after indexing completion: {storage}")
|
||||
@@ -811,7 +866,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
|
||||
for attempt in active_attempts:
|
||||
try:
|
||||
monitor_indexing_attempt_progress(attempt, tenant_id, db_session)
|
||||
monitor_indexing_attempt_progress(
|
||||
attempt, tenant_id, db_session, self
|
||||
)
|
||||
except Exception:
|
||||
task_logger.exception(f"Error monitoring attempt {attempt.id}")
|
||||
|
||||
@@ -1085,12 +1142,8 @@ def _docprocessing_task(
|
||||
f"Index attempt {index_attempt_id} is not running, status {index_attempt.status}"
|
||||
)
|
||||
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
index_attempt.search_settings.id
|
||||
)
|
||||
|
||||
cross_batch_db_lock: RedisLock = r.lock(
|
||||
redis_connector_index.db_lock_key,
|
||||
redis_connector.db_lock_key(index_attempt.search_settings.id),
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
@@ -1230,17 +1283,6 @@ def _docprocessing_task(
|
||||
f"attempt={index_attempt_id} "
|
||||
)
|
||||
|
||||
# on failure, signal completion with an error to unblock the watchdog
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(db_session, index_attempt_id)
|
||||
if index_attempt and index_attempt.search_settings:
|
||||
redis_connector_index = redis_connector.new_index(
|
||||
index_attempt.search_settings.id
|
||||
)
|
||||
redis_connector_index.set_generator_complete(
|
||||
HTTPStatus.INTERNAL_SERVER_ERROR.value
|
||||
)
|
||||
|
||||
raise
|
||||
finally:
|
||||
if per_batch_lock and per_batch_lock.owned():
|
||||
|
||||
@@ -47,7 +47,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import SyncStatus
|
||||
from onyx.db.enums import SyncType
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.sync_record import insert_sync_record
|
||||
from onyx.db.sync_record import update_sync_record_status
|
||||
from onyx.db.tag import delete_orphan_tags__no_commit
|
||||
@@ -70,9 +69,9 @@ logger = setup_logger()
|
||||
def _get_pruning_block_expiration() -> int:
|
||||
"""
|
||||
Compute the expiration time for the pruning block signal.
|
||||
Base expiration is 3600 seconds (1 hour), multiplied by the beat multiplier only in MULTI_TENANT mode.
|
||||
Base expiration is 60 seconds (1 minute), multiplied by the beat multiplier only in MULTI_TENANT mode.
|
||||
"""
|
||||
base_expiration = 3600 # seconds
|
||||
base_expiration = 60 # seconds
|
||||
|
||||
if not MULTI_TENANT:
|
||||
return base_expiration
|
||||
@@ -145,10 +144,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
last_pruned = cc_pair.connector.time_created
|
||||
|
||||
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
|
||||
if datetime.now(timezone.utc) < next_prune:
|
||||
return False
|
||||
|
||||
return True
|
||||
return datetime.now(timezone.utc) >= next_prune
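
The same check in isolation, with the connector fields flattened into plain arguments (argument names are hypothetical):

from datetime import datetime, timedelta, timezone


def is_pruning_due(last_pruned: datetime, prune_freq_seconds: int) -> bool:
    # due once the configured interval has elapsed since the last prune
    return datetime.now(timezone.utc) >= last_pruned + timedelta(seconds=prune_freq_seconds)


print(is_pruning_due(datetime(2024, 1, 1, tzinfo=timezone.utc), 86_400))  # True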
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -280,6 +276,9 @@ def try_creating_prune_generator_task(
|
||||
if not ALLOW_SIMULTANEOUS_PRUNING:
|
||||
count = redis_connector.prune.get_active_task_count()
|
||||
if count > 0:
|
||||
logger.info(
|
||||
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} no simultaneous pruning allowed"
|
||||
)
|
||||
return None
|
||||
|
||||
LOCK_TIMEOUT = 30
|
||||
@@ -293,6 +292,9 @@ def try_creating_prune_generator_task(
|
||||
|
||||
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
|
||||
if not acquired:
|
||||
logger.info(
|
||||
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} lock not acquired"
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
@@ -516,9 +518,6 @@ def connector_pruning_generator_task(
|
||||
cc_pair.credential,
|
||||
)
|
||||
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
redis_connector.new_index(search_settings.id)
|
||||
|
||||
callback = PruneCallback(
|
||||
0,
|
||||
redis_connector,
|
||||
|
||||
@@ -226,8 +226,12 @@ def _check_connector_and_attempt_status(
|
||||
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
|
||||
|
||||
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
|
||||
error_str = ""
|
||||
if index_attempt_loop.error_msg:
|
||||
error_str = f" Original error: {index_attempt_loop.error_msg}"
|
||||
|
||||
raise RuntimeError(
|
||||
f"Index Attempt is not running, status is {index_attempt_loop.status}"
|
||||
f"Index Attempt is not running, status is {index_attempt_loop.status}.{error_str}"
|
||||
)
|
||||
|
||||
if index_attempt_loop.celery_task_id is None:
|
||||
@@ -832,7 +836,7 @@ def _run_indexing(
|
||||
)
|
||||
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
def run_docfetching_entrypoint(
|
||||
app: Celery,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str,
|
||||
@@ -1350,6 +1354,9 @@ def reissue_old_batches(
|
||||
)
|
||||
path_info = batch_storage.extract_path_info(batch_id)
|
||||
if path_info is None:
|
||||
logger.warning(
|
||||
f"Could not extract path info from batch {batch_id}, skipping"
|
||||
)
|
||||
continue
|
||||
if path_info.cc_pair_id != cc_pair_id:
|
||||
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")
|
||||
|
||||
@@ -359,6 +359,12 @@ POLL_CONNECTOR_OFFSET = 30 # Minutes overlap between poll windows
|
||||
# only very select connectors are enabled and admins cannot add other connector types
|
||||
ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
|
||||
|
||||
# If set to true, curators can only access and edit assistants that they created
|
||||
CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS = (
|
||||
os.environ.get("CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS", "").lower()
|
||||
== "true"
|
||||
)
|
||||
|
||||
# Some calls to get information on expert users are quite costly especially with rate limiting
|
||||
# Since experts are not used in the actual user experience, currently it is turned off
|
||||
# for some connectors
|
||||
|
||||
@@ -25,9 +25,32 @@ TimeRange = tuple[datetime, datetime]
|
||||
CT = TypeVar("CT", bound=ConnectorCheckpoint)
|
||||
|
||||
|
||||
def batched_doc_ids(
|
||||
checkpoint_connector_generator: CheckpointOutput[CT],
|
||||
batch_size: int,
|
||||
) -> Generator[set[str], None, None]:
|
||||
batch: set[str] = set()
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
if document is not None:
|
||||
batch.add(document.id)
|
||||
elif (
|
||||
failure and failure.failed_document and failure.failed_document.document_id
|
||||
):
|
||||
batch.add(failure.failed_document.document_id)
|
||||
|
||||
if len(batch) >= batch_size:
|
||||
yield batch
|
||||
batch = set()
|
||||
if len(batch) > 0:
|
||||
yield batch
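
batched_doc_ids flushes a set of IDs every batch_size documents and emits a trailing partial batch. A simplified stand-alone version over a plain iterable (no checkpoint or failure plumbing) shows the same flushing behavior:

from collections.abc import Generator, Iterable


def batch_ids(ids: Iterable[str], batch_size: int) -> Generator[set[str], None, None]:
    batch: set[str] = set()
    for doc_id in ids:
        batch.add(doc_id)
        if len(batch) >= batch_size:
            yield batch
            batch = set()
    if batch:
        yield batch


# three batches: two full sets of 2 IDs, then a trailing set of 1
print(list(batch_ids(["a", "b", "c", "d", "e"], batch_size=2)))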
|
||||
|
||||
|
||||
class CheckpointOutputWrapper(Generic[CT]):
|
||||
"""
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format.
|
||||
Wraps a CheckpointOutput generator to give things back in a more digestible format,
|
||||
specifically for Document outputs.
|
||||
The connector format is easier for the connector implementor (e.g. it enforces exactly
|
||||
one new checkpoint is returned AND that the checkpoint is at the end), thus the different
|
||||
formats.
|
||||
@@ -131,7 +154,7 @@ class ConnectorRunner(Generic[CT]):
|
||||
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
|
||||
checkpoint_connector_generator
|
||||
):
|
||||
if document is not None:
|
||||
if document is not None and isinstance(document, Document):
|
||||
self.doc_batch.append(document)
|
||||
|
||||
if failure is not None:
|
||||
|
||||
@@ -222,11 +222,21 @@ class LocalFileConnector(LoadConnector):
|
||||
"""
|
||||
Connector that reads files from Postgres and yields Documents, including
|
||||
embedded image extraction without summarization.
|
||||
|
||||
file_locations are S3/Filestore UUIDs
|
||||
file_names are the names of the files
|
||||
"""
|
||||
|
||||
# Note: file_names is a required parameter, but should not break backwards compatibility.
|
||||
# If add_file_names migration is not run, old file connector configs will not have file_names.
|
||||
# This is fine because the configs are not re-used to instantiate the connector.
|
||||
# file_names is only used for display purposes in the UI and file_locations is used as a fallback.
|
||||
def __init__(
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
file_names: list[
|
||||
str
|
||||
], # Must accept this parameter as connector_specific_config is unpacked as args
|
||||
zip_metadata: dict[str, Any],
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
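
For context on the "unpacked as args" comment: the stored connector_specific_config dict is passed to the constructor as keyword arguments, which is why file_names must be accepted even though it is only used for display. A hedged sketch with a made-up config (the UUID and file name are placeholders):

# Sketch only: how a stored connector_specific_config maps onto the constructor.
connector_specific_config = {
    "file_locations": ["file-store-uuid-1"],  # S3/Filestore UUIDs
    "file_names": ["quarterly_report.pdf"],   # display names for the UI
    "zip_metadata": {},
}
# connector = LocalFileConnector(**connector_specific_config)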
|
||||
@@ -260,7 +270,7 @@ class LocalFileConnector(LoadConnector):
|
||||
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
|
||||
continue
|
||||
|
||||
metadata = self._get_file_metadata(file_id)
|
||||
metadata = self._get_file_metadata(file_record.display_name)
|
||||
file_io = file_store.read_file(file_id=file_id, mode="b")
|
||||
new_docs = _process_file(
|
||||
file_id=file_id,
|
||||
@@ -282,7 +292,9 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[os.environ["TEST_FILE"]], zip_metadata={}
|
||||
file_locations=[os.environ["TEST_FILE"]],
|
||||
file_names=[os.environ["TEST_FILE"]],
|
||||
zip_metadata={},
|
||||
)
|
||||
connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")})
|
||||
doc_batches = connector.load_from_state()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import copy
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
@@ -17,17 +16,22 @@ from github.Issue import Issue
|
||||
from github.NamedUser import NamedUser
|
||||
from github.PaginatedList import PaginatedList
|
||||
from github.PullRequest import PullRequest
|
||||
from github.Requester import Requester
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
from onyx.connectors.exceptions import ConnectorValidationError
|
||||
from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import CheckpointedConnector
|
||||
from onyx.connectors.github.models import SerializedRepository
|
||||
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
|
||||
from onyx.connectors.github.utils import deserialize_repository
|
||||
from onyx.connectors.github.utils import get_external_access_permission
|
||||
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
|
||||
from onyx.connectors.interfaces import CheckpointOutput
|
||||
from onyx.connectors.interfaces import ConnectorCheckpoint
|
||||
from onyx.connectors.interfaces import ConnectorFailure
|
||||
@@ -46,17 +50,7 @@ CURSOR_LOG_FREQUENCY = 50
|
||||
_MAX_NUM_RATE_LIMIT_RETRIES = 5
|
||||
|
||||
ONE_DAY = timedelta(days=1)
|
||||
|
||||
|
||||
def _sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
sleep_time = github_client.get_rate_limit().core.reset.replace(
|
||||
tzinfo=timezone.utc
|
||||
) - datetime.now(tz=timezone.utc)
|
||||
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
|
||||
logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
|
||||
time.sleep(sleep_time.seconds)
|
||||
|
||||
|
||||
SLIM_BATCH_SIZE = 100
|
||||
# Cases
|
||||
# X (from start) standard run, no fallback to cursor-based pagination
|
||||
# X (from start) standard run errors, fallback to cursor-based pagination
|
||||
@@ -72,6 +66,10 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
# checkpoint progress (no infinite loop)
|
||||
|
||||
|
||||
class DocMetadata(BaseModel):
|
||||
repo: str
|
||||
|
||||
|
||||
def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str:
|
||||
if "_PaginatedList__nextUrl" in pag_list.__dict__:
|
||||
return "_PaginatedList__nextUrl"
|
||||
@@ -190,7 +188,7 @@ def _get_batch_rate_limited(
|
||||
getattr(obj, "raw_data")
|
||||
yield from objs
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
yield from _get_batch_rate_limited(
|
||||
git_objs,
|
||||
page_num,
|
||||
@@ -232,12 +230,17 @@ def _get_userinfo(user: NamedUser) -> dict[str, str]:
|
||||
}
|
||||
|
||||
|
||||
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
def _convert_pr_to_document(
|
||||
pull_request: PullRequest, repo_external_access: ExternalAccess | None
|
||||
) -> Document:
|
||||
repo_name = pull_request.base.repo.full_name if pull_request.base else ""
|
||||
doc_metadata = DocMetadata(repo=repo_name)
|
||||
return Document(
|
||||
id=pull_request.html_url,
|
||||
sections=[
|
||||
TextSection(link=pull_request.html_url, text=pull_request.body or "")
|
||||
],
|
||||
external_access=repo_external_access,
|
||||
source=DocumentSource.GITHUB,
|
||||
semantic_identifier=f"{pull_request.number}: {pull_request.title}",
|
||||
# updated_at is UTC time but is timezone unaware, explicitly add UTC
|
||||
@@ -248,6 +251,8 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
if pull_request.updated_at
|
||||
else None
|
||||
),
|
||||
# this metadata is used in perm sync
|
||||
doc_metadata=doc_metadata.model_dump(),
|
||||
metadata={
|
||||
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
|
||||
for k, v in {
|
||||
@@ -301,14 +306,21 @@ def _fetch_issue_comments(issue: Issue) -> str:
|
||||
return "\nComment: ".join(comment.body for comment in comments)
|
||||
|
||||
|
||||
def _convert_issue_to_document(issue: Issue) -> Document:
|
||||
def _convert_issue_to_document(
|
||||
issue: Issue, repo_external_access: ExternalAccess | None
|
||||
) -> Document:
|
||||
repo_name = issue.repository.full_name if issue.repository else ""
|
||||
doc_metadata = DocMetadata(repo=repo_name)
|
||||
return Document(
|
||||
id=issue.html_url,
|
||||
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
|
||||
source=DocumentSource.GITHUB,
|
||||
external_access=repo_external_access,
|
||||
semantic_identifier=f"{issue.number}: {issue.title}",
|
||||
# updated_at is UTC time but is timezone unaware
|
||||
doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
|
||||
# this metadata is used in perm sync
|
||||
doc_metadata=doc_metadata.model_dump(),
|
||||
metadata={
|
||||
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
|
||||
for k, v in {
|
||||
@@ -343,18 +355,6 @@ def _convert_issue_to_document(issue: Issue) -> Document:
|
||||
)
|
||||
|
||||
|
||||
class SerializedRepository(BaseModel):
|
||||
# id is part of the raw_data as well, just pulled out for convenience
|
||||
id: int
|
||||
headers: dict[str, str | int]
|
||||
raw_data: dict[str, Any]
|
||||
|
||||
def to_Repository(self, requester: Requester) -> Repository.Repository:
|
||||
return Repository.Repository(
|
||||
requester, self.headers, self.raw_data, completed=True
|
||||
)
|
||||
|
||||
|
||||
class GithubConnectorStage(Enum):
|
||||
START = "start"
|
||||
PRS = "prs"
|
||||
@@ -394,7 +394,7 @@ def make_cursor_url_callback(
|
||||
return cursor_url_callback
|
||||
|
||||
|
||||
class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint]):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
@@ -423,7 +423,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_github_repo(
|
||||
def get_github_repo(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> Repository.Repository:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
@@ -434,10 +434,10 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
try:
|
||||
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_github_repo(github_client, attempt_num + 1)
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
return self.get_github_repo(github_client, attempt_num + 1)
|
||||
|
||||
def _get_github_repos(
|
||||
def get_github_repos(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Repository.Repository]:
|
||||
"""Get specific repositories based on comma-separated repo_name string."""
|
||||
@@ -465,10 +465,10 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
|
||||
return repos
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_github_repos(github_client, attempt_num + 1)
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
return self.get_github_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _get_all_repos(
|
||||
def get_all_repos(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Repository.Repository]:
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
@@ -487,8 +487,8 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
user = github_client.get_user(self.repo_owner)
|
||||
return list(user.get_repos())
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_all_repos(github_client, attempt_num + 1)
|
||||
sleep_after_rate_limit_exception(github_client)
|
||||
return self.get_all_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _pull_requests_func(
|
||||
self, repo: Repository.Repository
|
||||
@@ -509,6 +509,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
include_permissions: bool = False,
|
||||
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
@@ -521,13 +522,13 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
repos = self.get_github_repos(self.github_client)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
repos = [self.get_github_repo(self.github_client)]
|
||||
else:
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
repos = self.get_all_repos(self.github_client)
|
||||
if not repos:
|
||||
checkpoint.has_more = False
|
||||
return checkpoint
|
||||
@@ -547,28 +548,15 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
if checkpoint.cached_repo is None:
|
||||
raise ValueError("No repo saved in checkpoint")
|
||||
|
||||
# Try to access the requester - different PyGithub versions may use different attribute names
|
||||
try:
|
||||
# Try direct access to a known attribute name first
|
||||
if hasattr(self.github_client, "_requester"):
|
||||
requester = self.github_client._requester
|
||||
elif hasattr(self.github_client, "_Github__requester"):
|
||||
requester = self.github_client._Github__requester
|
||||
else:
|
||||
# If we can't find the requester attribute, we need to fall back to recreating the repo
|
||||
raise AttributeError("Could not find requester attribute")
|
||||
|
||||
repo = checkpoint.cached_repo.to_Repository(requester)
|
||||
except Exception as e:
|
||||
# If all else fails, re-fetch the repo directly
|
||||
logger.warning(
|
||||
f"Failed to deserialize repository: {e}. Attempting to re-fetch."
|
||||
)
|
||||
repo_id = checkpoint.cached_repo.id
|
||||
repo = self.github_client.get_repo(repo_id)
|
||||
# Deserialize the repository from the checkpoint
|
||||
repo = deserialize_repository(checkpoint.cached_repo, self.github_client)
|
||||
|
||||
cursor_url_callback = make_cursor_url_callback(checkpoint)
|
||||
|
||||
repo_external_access: ExternalAccess | None = None
|
||||
if include_permissions:
|
||||
repo_external_access = get_external_access_permission(
|
||||
repo, self.github_client
|
||||
)
|
||||
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
|
||||
logger.info(f"Fetching PRs for repo: {repo.name}")
|
||||
|
||||
@@ -603,7 +591,9 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
):
|
||||
continue
|
||||
try:
|
||||
yield _convert_pr_to_document(cast(PullRequest, pr))
|
||||
yield _convert_pr_to_document(
|
||||
cast(PullRequest, pr), repo_external_access
|
||||
)
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting PR to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
@@ -653,6 +643,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
self.github_client,
|
||||
)
|
||||
)
|
||||
logger.info(f"Fetched {len(issue_batch)} issues for repo: {repo.name}")
|
||||
checkpoint.curr_page += 1
|
||||
done_with_issues = False
|
||||
num_issues = 0
|
||||
@@ -678,7 +669,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
continue
|
||||
|
||||
try:
|
||||
yield _convert_issue_to_document(issue)
|
||||
yield _convert_issue_to_document(issue, repo_external_access)
|
||||
except Exception as e:
|
||||
error_msg = f"Error converting issue to document: {e}"
|
||||
logger.exception(error_msg)
|
||||
@@ -715,12 +706,16 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
checkpoint.stage = GithubConnectorStage.PRS
|
||||
checkpoint.reset()
|
||||
|
||||
logger.info(f"{len(checkpoint.cached_repo_ids)} repos remaining")
|
||||
if checkpoint.cached_repo_ids:
|
||||
logger.info(
|
||||
f"{len(checkpoint.cached_repo_ids)} repos remaining (IDs: {checkpoint.cached_repo_ids})"
|
||||
)
|
||||
else:
|
||||
logger.info("No more repos remaining")
|
||||
|
||||
return checkpoint
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
def _load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
@@ -741,7 +736,32 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
adjusted_start_datetime = epoch
|
||||
|
||||
return self._fetch_from_github(
|
||||
checkpoint, start=adjusted_start_datetime, end=end_datetime
|
||||
checkpoint,
|
||||
start=adjusted_start_datetime,
|
||||
end=end_datetime,
|
||||
include_permissions=include_permissions,
|
||||
)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
) -> CheckpointOutput[GithubConnectorCheckpoint]:
|
||||
return self._load_from_checkpoint(
|
||||
start, end, checkpoint, include_permissions=False
|
||||
)
|
||||
|
||||
@override
|
||||
def load_from_checkpoint_with_perm_sync(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch,
|
||||
end: SecondsSinceUnixEpoch,
|
||||
checkpoint: GithubConnectorCheckpoint,
|
||||
) -> CheckpointOutput[GithubConnectorCheckpoint]:
|
||||
return self._load_from_checkpoint(
|
||||
start, end, checkpoint, include_permissions=True
|
||||
)
|
||||
|
||||
def validate_connector_settings(self) -> None:
|
||||
@@ -775,6 +795,9 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
test_repo = self.github_client.get_repo(
|
||||
f"{self.repo_owner}/{repo_name}"
|
||||
)
|
||||
logger.info(
|
||||
f"Successfully accessed repository: {self.repo_owner}/{repo_name}"
|
||||
)
|
||||
test_repo.get_contents("")
|
||||
valid_repos = True
|
||||
# If at least one repo is valid, we can proceed
|
||||
@@ -882,7 +905,6 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
|
||||
|
||||
if __name__ == "__main__":
|
||||
import os
|
||||
from onyx.connectors.connector_runner import ConnectorRunner
|
||||
|
||||
# Initialize the connector
|
||||
connector = GithubConnector(
|
||||
@@ -893,6 +915,12 @@ if __name__ == "__main__":
|
||||
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
|
||||
)
|
||||
|
||||
if connector.github_client:
|
||||
get_external_access_permission(
|
||||
connector.get_github_repos(connector.github_client).pop(),
|
||||
connector.github_client,
|
||||
)
|
||||
|
||||
# Create a time range from epoch to now
|
||||
end_time = datetime.now(timezone.utc)
|
||||
start_time = datetime.fromtimestamp(0, tz=timezone.utc)
|
||||
|
||||
backend/onyx/connectors/github/models.py (new file, 17 lines added)
@@ -0,0 +1,17 @@
|
||||
from typing import Any
|
||||
|
||||
from github import Repository
|
||||
from github.Requester import Requester
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class SerializedRepository(BaseModel):
|
||||
# id is part of the raw_data as well, just pulled out for convenience
|
||||
id: int
|
||||
headers: dict[str, str | int]
|
||||
raw_data: dict[str, Any]
|
||||
|
||||
def to_Repository(self, requester: Requester) -> Repository.Repository:
|
||||
return Repository.Repository(
|
||||
requester, self.headers, self.raw_data, completed=True
|
||||
)
|
||||
backend/onyx/connectors/github/rate_limit_utils.py (new file, 25 lines added)
@@ -0,0 +1,25 @@
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
|
||||
from github import Github
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def sleep_after_rate_limit_exception(github_client: Github) -> None:
|
||||
"""
|
||||
Sleep until the GitHub rate limit resets.
|
||||
|
||||
Args:
|
||||
github_client: The GitHub client that hit the rate limit
|
||||
"""
|
||||
sleep_time = github_client.get_rate_limit().core.reset.replace(
|
||||
tzinfo=timezone.utc
|
||||
) - datetime.now(tz=timezone.utc)
|
||||
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
|
||||
logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
|
||||
time.sleep(sleep_time.total_seconds())
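
A typical call site for the helper above, offered as a usage sketch rather than Onyx's exact retry logic: catch PyGithub's RateLimitExceededException, sleep until the limit resets, then retry.

from github import Github, RateLimitExceededException

from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception


def list_repo_names_with_retry(client: Github, org: str, attempts: int = 3) -> list[str]:
    for _ in range(attempts):
        try:
            return [repo.name for repo in client.get_organization(org).get_repos()]
        except RateLimitExceededException:
            sleep_after_rate_limit_exception(client)
    raise RuntimeError("GitHub rate limit retries exhausted")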
|
||||
backend/onyx/connectors/github/utils.py (new file, 63 lines added)
@@ -0,0 +1,63 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from github import Github
|
||||
from github.Repository import Repository
|
||||
|
||||
from onyx.access.models import ExternalAccess
|
||||
from onyx.connectors.github.models import SerializedRepository
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_external_access_permission(
|
||||
repo: Repository, github_client: Github
|
||||
) -> ExternalAccess:
|
||||
"""
|
||||
Get the external access permission for a repository.
|
||||
This functionality requires Enterprise Edition.
|
||||
"""
|
||||
# Check if EE is enabled
|
||||
if not global_version.is_ee_version():
|
||||
# For the MIT version, return an empty ExternalAccess (private document)
|
||||
return ExternalAccess.empty()
|
||||
|
||||
# Fetch the EE implementation
|
||||
ee_get_external_access_permission = cast(
|
||||
Callable[[Repository, Github, bool], ExternalAccess],
|
||||
fetch_versioned_implementation(
|
||||
"onyx.external_permissions.github.utils",
|
||||
"get_external_access_permission",
|
||||
),
|
||||
)
|
||||
|
||||
return ee_get_external_access_permission(repo, github_client, True)
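
fetch_versioned_implementation is Onyx's loader for swapping in the EE module when it is available; the sketch below shows the general shape of such a dispatch with importlib, using a hypothetical module path and fallback, not the actual loader:

import importlib
from collections.abc import Callable


def load_versioned(module_path: str, attr: str, fallback: Callable) -> Callable:
    """Return attr from module_path if importable, else the fallback implementation."""
    try:
        return getattr(importlib.import_module(module_path), attr)
    except (ModuleNotFoundError, AttributeError):
        return fallback


def mit_get_external_access_permission(*_args, **_kwargs) -> None:
    # MIT fallback: treat the document as private (stand-in for ExternalAccess.empty())
    return None


get_perms = load_versioned(
    "ee.onyx.external_permissions.github.utils",  # hypothetical EE module path
    "get_external_access_permission",
    mit_get_external_access_permission,
)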
|
||||
|
||||
|
||||
def deserialize_repository(
|
||||
cached_repo: SerializedRepository, github_client: Github
|
||||
) -> Repository:
|
||||
"""
|
||||
Deserialize a SerializedRepository back into a Repository object.
|
||||
"""
|
||||
# Try to access the requester - different PyGithub versions may use different attribute names
|
||||
try:
|
||||
# Try to get the requester using getattr to avoid linter errors
|
||||
requester = getattr(github_client, "_requester", None)
|
||||
if requester is None:
|
||||
requester = getattr(github_client, "_Github__requester", None)
|
||||
if requester is None:
|
||||
# If we can't find the requester attribute, we need to fall back to recreating the repo
|
||||
raise AttributeError("Could not find requester attribute")
|
||||
|
||||
return cached_repo.to_Repository(requester)
|
||||
except Exception as e:
|
||||
# If all else fails, re-fetch the repo directly
|
||||
logger.warning(
|
||||
f"Failed to deserialize repository: {e}. Attempting to re-fetch."
|
||||
)
|
||||
repo_id = cached_repo.id
|
||||
return github_client.get_repo(repo_id)
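
Usage sketch for the helpers above: a Repository can be cached into the checkpoint through its raw_data and raw_headers (PyGithub properties) and restored later with deserialize_repository. This mirrors how SerializedRepository appears to be used, but is not copied from the source:

from github import Github

from onyx.connectors.github.models import SerializedRepository
from onyx.connectors.github.utils import deserialize_repository


def cache_and_restore(client: Github, full_name: str) -> None:
    repo = client.get_repo(full_name)
    cached = SerializedRepository(
        id=repo.id,
        headers=dict(repo.raw_headers),
        raw_data=dict(repo.raw_data),
    )
    restored = deserialize_repository(cached, client)
    assert restored.full_name == repo.full_name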
|
||||
@@ -183,6 +183,7 @@ class DocumentBase(BaseModel):

    # only filled in EE for connectors w/ permission sync enabled
    external_access: ExternalAccess | None = None
    doc_metadata: dict[str, Any] | None = None

    def get_title_for_document_index(
        self,
@@ -1,5 +1,6 @@
|
||||
import csv
|
||||
import gc
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
@@ -28,8 +29,12 @@ from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
|
||||
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
|
||||
from onyx.connectors.salesforce.utils import get_sqlite_db_path
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
@@ -38,27 +43,27 @@ from shared_configs.configs import MULTI_TENANT
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
|
||||
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
|
||||
|
||||
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
|
||||
"Opportunity": {
|
||||
"Account": "account",
|
||||
ACCOUNT_OBJECT_TYPE: "account",
|
||||
"FiscalQuarter": "fiscal_quarter",
|
||||
"FiscalYear": "fiscal_year",
|
||||
"IsClosed": "is_closed",
|
||||
"Name": "name",
|
||||
NAME_FIELD: "name",
|
||||
"StageName": "stage_name",
|
||||
"Type": "type",
|
||||
"Amount": "amount",
|
||||
"CloseDate": "close_date",
|
||||
"Probability": "probability",
|
||||
"CreatedDate": "created_date",
|
||||
"LastModifiedDate": "last_modified_date",
|
||||
MODIFIED_FIELD: "last_modified_date",
|
||||
},
|
||||
"Contact": {
|
||||
"Account": "account",
|
||||
ACCOUNT_OBJECT_TYPE: "account",
|
||||
"CreatedDate": "created_date",
|
||||
"LastModifiedDate": "last_modified_date",
|
||||
MODIFIED_FIELD: "last_modified_date",
|
||||
},
|
||||
}
|
||||
|
||||
@@ -74,19 +79,77 @@ class SalesforceConnectorContext:
|
||||
parent_to_child_types: dict[str, set[str]] = {} # map from parent to child types
|
||||
child_to_parent_types: dict[str, set[str]] = {} # map from child to parent types
|
||||
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {}
|
||||
type_to_queryable_fields: dict[str, list[str]] = {}
|
||||
type_to_queryable_fields: dict[str, set[str]] = {}
|
||||
prefix_to_type: dict[str, str] = {} # infer the object type of an id immediately
|
||||
|
||||
parent_to_child_relationships: dict[str, set[str]] = (
|
||||
{}
|
||||
) # map from parent to child relationships
|
||||
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = (
|
||||
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = (
|
||||
{}
|
||||
) # map from relationship to queryable fields
|
||||
|
||||
parent_child_names_to_relationships: dict[str, str] = {}
|
||||
|
||||
|
||||
def _extract_fields_and_associations_from_config(
|
||||
config: dict[str, Any], object_type: str
|
||||
) -> tuple[list[str] | None, dict[str, list[str]]]:
|
||||
"""
|
||||
Extract fields and associations for a specific object type from custom config.
|
||||
|
||||
Returns:
|
||||
tuple of (fields_list, associations_dict)
|
||||
- fields_list: List of fields to query, or None if not specified (use all)
|
||||
- associations_dict: Dict mapping association names to their config
|
||||
"""
|
||||
if object_type not in config:
|
||||
return None, {}
|
||||
|
||||
obj_config = config[object_type]
|
||||
fields = obj_config.get("fields")
|
||||
associations = obj_config.get("associations", {})
|
||||
|
||||
return fields, associations
|
||||
|
||||
|
||||
def _validate_custom_query_config(config: dict[str, Any]) -> None:
|
||||
"""
|
||||
Validate the structure of the custom query configuration.
|
||||
"""
|
||||
|
||||
for object_type, obj_config in config.items():
|
||||
if not isinstance(obj_config, dict):
|
||||
raise ValueError(
|
||||
f"top level object {object_type} must be mapped to a dictionary"
|
||||
)
|
||||
|
||||
# Check if fields is a list when present
|
||||
if "fields" in obj_config:
|
||||
if not isinstance(obj_config["fields"], list):
|
||||
raise ValueError("if fields key exists, value must be a list")
|
||||
for v in obj_config["fields"]:
|
||||
if not isinstance(v, str):
|
||||
raise ValueError(f"if fields list value {v} is not a string")
|
||||
|
||||
# Check if associations is a dict when present
|
||||
if "associations" in obj_config:
|
||||
if not isinstance(obj_config["associations"], dict):
|
||||
raise ValueError(
|
||||
"if associations key exists, value must be a dictionary"
|
||||
)
|
||||
for assoc_name, assoc_fields in obj_config["associations"].items():
|
||||
if not isinstance(assoc_fields, list):
|
||||
raise ValueError(
|
||||
f"associations list value {assoc_fields} for key {assoc_name} is not a list"
|
||||
)
|
||||
for v in assoc_fields:
|
||||
if not isinstance(v, str):
|
||||
raise ValueError(
|
||||
f"if associations list value {v} is not a string"
|
||||
)
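
The validator above implies a concrete shape for custom_query_config. A hypothetical example that would pass the structural checks (object, field, and association names are illustrative, not taken from the source):

import json

# Top-level keys are Salesforce object types; each may carry an optional
# "fields" list and an optional "associations" dict mapping child relationship
# names to the fields to pull from them.
custom_query_config = json.dumps(
    {
        "Opportunity": {
            "fields": ["Name", "StageName", "Amount", "CloseDate"],
            "associations": {"OpportunityContactRoles": ["Role"]},
        },
        "Account": {
            "fields": ["Name", "Industry"],
        },
    }
)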
|
||||
|
||||
|
||||
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
"""Approach outline
|
||||
|
||||
@@ -134,14 +197,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
requested_objects: list[str] = [],
|
||||
custom_query_config: str | None = None,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self._sf_client: OnyxSalesforce | None = None
|
||||
self.parent_object_list = (
|
||||
[obj.capitalize() for obj in requested_objects]
|
||||
if requested_objects
|
||||
else _DEFAULT_PARENT_OBJECT_TYPES
|
||||
)
|
||||
|
||||
# Validate and store custom query config
|
||||
if custom_query_config:
|
||||
config_json = json.loads(custom_query_config)
|
||||
self.custom_query_config: dict[str, Any] | None = config_json
|
||||
# If custom query config is provided, use the object types from it
|
||||
self.parent_object_list = list(config_json.keys())
|
||||
else:
|
||||
self.custom_query_config = None
|
||||
# Use the traditional requested_objects approach
|
||||
self.parent_object_list = (
|
||||
[obj.strip().capitalize() for obj in requested_objects]
|
||||
if requested_objects
|
||||
else _DEFAULT_PARENT_OBJECT_TYPES
|
||||
)
|
||||
|
||||
def load_credentials(
|
||||
self,
|
||||
@@ -187,7 +261,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
@staticmethod
|
||||
def _download_object_csvs(
|
||||
all_types_to_filter: dict[str, bool],
|
||||
queryable_fields_by_type: dict[str, list[str]],
|
||||
queryable_fields_by_type: dict[str, set[str]],
|
||||
directory: str,
|
||||
sf_client: OnyxSalesforce,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
@@ -325,9 +399,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# all_types.update(child_types.keys())
|
||||
|
||||
# # Always want to make sure user is grabbed for permissioning purposes
|
||||
# all_types.add("User")
|
||||
# all_types.add(USER_OBJECT_TYPE)
|
||||
# # Always want to make sure account is grabbed for reference purposes
|
||||
# all_types.add("Account")
|
||||
# all_types.add(ACCOUNT_OBJECT_TYPE)
|
||||
|
||||
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
|
||||
|
||||
@@ -351,7 +425,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# all_types.update(child_types)
|
||||
|
||||
# # Always want to make sure user is grabbed for permissioning purposes
|
||||
# all_types.add("User")
|
||||
# all_types.add(USER_OBJECT_TYPE)
|
||||
|
||||
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
|
||||
|
||||
@@ -364,7 +438,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> GenerateDocumentsOutput:
|
||||
type_to_processed: dict[str, int] = {}
|
||||
|
||||
logger.info("_fetch_from_salesforce starting.")
|
||||
logger.info("_fetch_from_salesforce starting (full sync).")
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
@@ -548,7 +622,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
) -> GenerateDocumentsOutput:
|
||||
type_to_processed: dict[str, int] = {}
|
||||
|
||||
logger.info("_fetch_from_salesforce starting.")
|
||||
logger.info("_fetch_from_salesforce starting (delta sync).")
|
||||
if not self._sf_client:
|
||||
raise RuntimeError("self._sf_client is None!")
|
||||
|
||||
@@ -677,7 +751,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
try:
|
||||
last_modified_by_id = record["LastModifiedById"]
|
||||
user_record = self.sf_client.query_object(
|
||||
"User", last_modified_by_id, ctx.type_to_queryable_fields
|
||||
USER_OBJECT_TYPE,
|
||||
last_modified_by_id,
|
||||
ctx.type_to_queryable_fields,
|
||||
)
|
||||
if user_record:
|
||||
primary_owner = BasicExpertInfo.from_dict(user_record)
|
||||
@@ -792,7 +868,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = (
{}
) # for a given object, the fields reference parent objects
type_to_queryable_fields: dict[str, list[str]] = {}
type_to_queryable_fields: dict[str, set[str]] = {}
prefix_to_type: dict[str, str] = {}

parent_to_child_relationships: dict[str, set[str]] = (
@@ -802,15 +878,13 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# relationship keys are formatted as "parent__relationship"
# we have to do this because relationship names are not unique!
# values are a dict of relationship names to a list of queryable fields
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = {}
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = {}

parent_child_names_to_relationships: dict[str, str] = {}

full_sync = False
if start is None and end is None:
full_sync = True
full_sync = start is None and end is None

# Step 1 - make a list of all the types to download (parent + direct child + "User")
# Step 1 - make a list of all the types to download (parent + direct child + USER_OBJECT_TYPE)
# prefixes = {}

global_description = sf_client.describe()
@@ -831,16 +905,58 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
|
||||
for parent_type in parent_types:
|
||||
# parent_onyx_sf_type = OnyxSalesforceType(parent_type, sf_client)
|
||||
type_to_queryable_fields[parent_type] = (
|
||||
sf_client.get_queryable_fields_by_type(parent_type)
|
||||
)
|
||||
|
||||
child_types_working = sf_client.get_children_of_sf_type(parent_type)
|
||||
logger.debug(
|
||||
f"Found {len(child_types_working)} child types for {parent_type}"
|
||||
)
|
||||
custom_fields: list[str] | None = []
|
||||
associations_config: dict[str, list[str]] | None = None
|
||||
|
||||
# parent_to_child_relationships[parent_type] = child_types_working
|
||||
# Set queryable fields for parent type
|
||||
if self.custom_query_config:
|
||||
custom_fields, associations_config = (
|
||||
_extract_fields_and_associations_from_config(
|
||||
self.custom_query_config, parent_type
|
||||
)
|
||||
)
|
||||
custom_fields = custom_fields or []
|
||||
|
||||
# Get custom fields for parent type
|
||||
field_set = set(custom_fields)
|
||||
# these are expected and used during doc conversion
|
||||
field_set.add(NAME_FIELD)
|
||||
field_set.add(MODIFIED_FIELD)
|
||||
|
||||
# Use only the specified fields
|
||||
type_to_queryable_fields[parent_type] = field_set
|
||||
logger.info(f"Using custom fields for {parent_type}: {field_set}")
|
||||
else:
|
||||
# Use all queryable fields
|
||||
type_to_queryable_fields[parent_type] = (
|
||||
sf_client.get_queryable_fields_by_type(parent_type)
|
||||
)
|
||||
logger.info(f"Using all fields for {parent_type}")
|
||||
|
||||
child_types_all = sf_client.get_children_of_sf_type(parent_type)
|
||||
logger.debug(f"Found {len(child_types_all)} child types for {parent_type}")
|
||||
logger.debug(f"child types: {child_types_all}")
|
||||
|
||||
child_types_working = child_types_all.copy()
|
||||
if associations_config is not None:
|
||||
child_types_working = {
|
||||
k: v for k, v in child_types_all.items() if k in associations_config
|
||||
}
|
||||
any_not_found = False
|
||||
for k in associations_config:
|
||||
if k not in child_types_working:
|
||||
any_not_found = True
|
||||
logger.warning(f"Association {k} not found in {parent_type}")
|
||||
if any_not_found:
|
||||
raise RuntimeError(
|
||||
f"Associations {associations_config} not found in {parent_type} "
|
||||
f"with child objects {child_types_all.keys()}"
|
||||
)
|
||||
|
||||
parent_to_child_relationships[parent_type] = set()
|
||||
parent_to_child_types[parent_type] = set()
|
||||
parent_to_relationship_queryable_fields[parent_type] = {}
|
||||
|
||||
for child_type, child_relationship in child_types_working.items():
|
||||
child_type = cast(str, child_type)
|
||||
@@ -848,8 +964,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
|
||||
|
||||
# map parent name to child name
|
||||
if parent_type not in parent_to_child_types:
|
||||
parent_to_child_types[parent_type] = set()
|
||||
parent_to_child_types[parent_type].add(child_type)
|
||||
|
||||
# reverse map child name to parent name
|
||||
@@ -858,19 +972,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
child_to_parent_types[child_type].add(parent_type)
|
||||
|
||||
# map parent name to child relationship
|
||||
if parent_type not in parent_to_child_relationships:
|
||||
parent_to_child_relationships[parent_type] = set()
|
||||
parent_to_child_relationships[parent_type].add(child_relationship)
|
||||
|
||||
# map relationship to queryable fields of the target table
|
||||
queryable_fields = sf_client.get_queryable_fields_by_type(child_type)
|
||||
if config_fields := (
|
||||
associations_config and associations_config.get(child_type)
|
||||
):
|
||||
field_set = set(config_fields)
|
||||
# these are expected and used during doc conversion
|
||||
field_set.add(NAME_FIELD)
|
||||
field_set.add(MODIFIED_FIELD)
|
||||
queryable_fields = field_set
|
||||
else:
|
||||
queryable_fields = sf_client.get_queryable_fields_by_type(
|
||||
child_type
|
||||
)
|
||||
|
||||
if child_relationship in parent_to_relationship_queryable_fields:
|
||||
raise RuntimeError(f"{child_relationship=} already exists")
|
||||
|
||||
if parent_type not in parent_to_relationship_queryable_fields:
|
||||
parent_to_relationship_queryable_fields[parent_type] = {}
|
||||
|
||||
parent_to_relationship_queryable_fields[parent_type][
|
||||
child_relationship
|
||||
] = queryable_fields
|
||||
@@ -894,14 +1014,22 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
all_types.update(child_types)
|
||||
|
||||
# NOTE(rkuo): should this be an implicit parent type?
|
||||
all_types.add("User") # Always add User for permissioning purposes
|
||||
all_types.add("Account") # Always add Account for reference purposes
|
||||
all_types.add(USER_OBJECT_TYPE) # Always add User for permissioning purposes
|
||||
all_types.add(ACCOUNT_OBJECT_TYPE) # Always add Account for reference purposes
|
||||
|
||||
logger.info(f"All object types: num={len(all_types)} list={all_types}")
|
||||
|
||||
# Ensure User and Account have queryable fields if they weren't already processed
|
||||
essential_types = [USER_OBJECT_TYPE, ACCOUNT_OBJECT_TYPE]
|
||||
for essential_type in essential_types:
|
||||
if essential_type not in type_to_queryable_fields:
|
||||
type_to_queryable_fields[essential_type] = (
|
||||
sf_client.get_queryable_fields_by_type(essential_type)
|
||||
)
|
||||
|
||||
# 1.1 - Detect all fields in child types which reference a parent type.
|
||||
# build dicts to detect relationships between parent and child
|
||||
for child_type in child_types:
|
||||
for child_type in child_types.union(essential_types):
|
||||
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
|
||||
parent_reference_fields = sf_client.get_parent_reference_fields(
|
||||
child_type, parent_types
|
||||
@@ -1003,6 +1131,32 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):

yield doc_metadata_list

def validate_connector_settings(self) -> None:
"""
Validate that the Salesforce credentials and connector settings are correct.
Specifically checks that we can make an authenticated request to Salesforce.
"""

try:
# Attempt to fetch a small batch of objects (arbitrary endpoint) to verify credentials
self.sf_client.describe()
except Exception as e:
raise ConnectorMissingCredentialError(
"Failed to validate Salesforce credentials. Please check your "
f"credentials and try again. Error: {e}"
)

if self.custom_query_config:
try:
_validate_custom_query_config(self.custom_query_config)
except Exception as e:
raise ConnectorMissingCredentialError(
"Failed to validate Salesforce custom query config. Please check your "
f"config and try again. Error: {e}"
)

logger.info("Salesforce credentials validated successfully.")

# @override
|
||||
# def load_from_checkpoint(
|
||||
# self,
|
||||
@@ -1032,7 +1186,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = SalesforceConnector(requested_objects=["Account"])
|
||||
connector = SalesforceConnector(requested_objects=[ACCOUNT_OBJECT_TYPE])
|
||||
|
||||
connector.load_credentials(
|
||||
{
|
||||
|
||||
@@ -10,6 +10,8 @@ from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -140,7 +142,7 @@ def _extract_primary_owner(
|
||||
first_name=user_data.get("FirstName"),
|
||||
last_name=user_data.get("LastName"),
|
||||
email=user_data.get("Email"),
|
||||
display_name=user_data.get("Name"),
|
||||
display_name=user_data.get(NAME_FIELD),
|
||||
)
|
||||
|
||||
# Check if all fields are None
|
||||
@@ -166,8 +168,8 @@ def convert_sf_query_result_to_doc(
|
||||
"""Generates a yieldable Document from query results"""
|
||||
|
||||
base_url = f"https://{sf_client.sf_instance}"
|
||||
extracted_doc_updated_at = time_str_to_utc(record["LastModifiedDate"])
|
||||
extracted_semantic_identifier = record.get("Name", "Unknown Object")
|
||||
extracted_doc_updated_at = time_str_to_utc(record[MODIFIED_FIELD])
|
||||
extracted_semantic_identifier = record.get(NAME_FIELD, "Unknown Object")
|
||||
|
||||
sections = [_extract_section(record, f"{base_url}/{record_id}")]
|
||||
for child_record_key, child_record in child_records.items():
|
||||
@@ -205,8 +207,8 @@ def convert_sf_object_to_doc(
|
||||
salesforce_id = object_dict["Id"]
|
||||
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
|
||||
base_url = f"https://{sf_instance}"
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
|
||||
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
|
||||
extracted_doc_updated_at = time_str_to_utc(object_dict[MODIFIED_FIELD])
|
||||
extracted_semantic_identifier = object_dict.get(NAME_FIELD, "Unknown Object")
|
||||
|
||||
sections = [_extract_section(sf_object.data, f"{base_url}/{sf_object.id}")]
|
||||
for id in sf_db.get_child_ids(sf_object.id):
|
||||
|
||||
@@ -60,7 +60,7 @@ class OnyxSalesforce(Salesforce):
return True

for suffix in SALESFORCE_BLACKLISTED_SUFFIXES:
if object_type_lower.endswith(prefix):
if object_type_lower.endswith(suffix):
return True

return False
@@ -112,7 +112,7 @@ class OnyxSalesforce(Salesforce):
|
||||
object_id: str,
|
||||
sf_type: str,
|
||||
child_relationships: list[str],
|
||||
relationships_to_fields: dict[str, list[str]],
|
||||
relationships_to_fields: dict[str, set[str]],
|
||||
) -> str:
|
||||
"""Returns a SOQL query given the object id, type and child relationships.
|
||||
|
||||
@@ -148,7 +148,7 @@ class OnyxSalesforce(Salesforce):
|
||||
self,
|
||||
object_type: str,
|
||||
object_id: str,
|
||||
type_to_queryable_fields: dict[str, list[str]],
|
||||
type_to_queryable_fields: dict[str, set[str]],
|
||||
) -> dict[str, Any] | None:
|
||||
record: dict[str, Any] = {}
|
||||
|
||||
@@ -172,7 +172,7 @@ class OnyxSalesforce(Salesforce):
|
||||
object_id: str,
|
||||
sf_type: str,
|
||||
child_relationships: list[str],
|
||||
relationships_to_fields: dict[str, list[str]],
|
||||
relationships_to_fields: dict[str, set[str]],
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""There's a limit on the number of subqueries we can put in a single query."""
|
||||
child_records: dict[str, dict[str, Any]] = {}
|
||||
@@ -264,10 +264,10 @@ class OnyxSalesforce(Salesforce):
time.sleep(3)
raise

def get_queryable_fields_by_type(self, name: str) -> list[str]:
def get_queryable_fields_by_type(self, name: str) -> set[str]:
object_description = self.describe_type(name)
if object_description is None:
return []
return set()

fields: list[dict[str, Any]] = object_description["fields"]
valid_fields: set[str] = set()
@@ -286,7 +286,7 @@ class OnyxSalesforce(Salesforce):
if field_name:
valid_fields.add(field_name)

return list(valid_fields - field_names_to_remove)
return valid_fields - field_names_to_remove

def get_children_of_sf_type(self, sf_type: str) -> dict[str, str]:
"""Returns a dict of child object names to relationship names.

@@ -14,6 +14,7 @@ from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
|
||||
rate_limit_builder,
|
||||
)
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
@@ -54,12 +55,12 @@ def _build_created_date_time_filter_for_salesforce(


def _make_time_filter_for_sf_type(
queryable_fields: list[str],
queryable_fields: set[str],
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> str | None:

if "LastModifiedDate" in queryable_fields:
if MODIFIED_FIELD in queryable_fields:
return _build_last_modified_time_filter_for_salesforce(start, end)

if "CreatedDate" in queryable_fields:
@@ -69,14 +70,14 @@ def _make_time_filter_for_sf_type(


def _make_time_filtered_query(
queryable_fields: list[str], sf_type: str, time_filter: str
queryable_fields: set[str], sf_type: str, time_filter: str
) -> str:
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
return query


def get_object_by_id_query(
object_id: str, sf_type: str, queryable_fields: list[str]
object_id: str, sf_type: str, queryable_fields: set[str]
) -> str:
query = (
f"SELECT {', '.join(queryable_fields)} FROM {sf_type} WHERE Id = '{object_id}'"
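As a rough illustration of the SOQL these helpers produce (object and field names below are hypothetical, field order is not guaranteed since the fields now come from a set, and the exact time-filter text comes from builders not shown in this change):

# Hypothetical illustration only, based on the f-strings above.
# With queryable_fields = {"Id", "Name", "LastModifiedDate"} and sf_type = "Account",
# _make_time_filtered_query(...) yields something like:
#   SELECT Id, Name, LastModifiedDate FROM Account WHERE LastModifiedDate >= ... AND LastModifiedDate < ...
# and get_object_by_id_query("001xx0000000001", "Account", queryable_fields) yields:
#   SELECT Id, Name, LastModifiedDate FROM Account WHERE Id = '001xx0000000001'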
@@ -193,7 +194,7 @@ def _bulk_retrieve_from_salesforce(
|
||||
def fetch_all_csvs_in_parallel(
|
||||
sf_client: Salesforce,
|
||||
all_types_to_filter: dict[str, bool],
|
||||
queryable_fields_by_type: dict[str, list[str]],
|
||||
queryable_fields_by_type: dict[str, set[str]],
|
||||
start: SecondsSinceUnixEpoch | None,
|
||||
end: SecondsSinceUnixEpoch | None,
|
||||
target_dir: str,
|
||||
|
||||
@@ -8,11 +8,15 @@ from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import NAME_FIELD
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import validate_salesforce_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.utils import batch_list
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -567,7 +571,7 @@ class OnyxSalesforceSQLite:
|
||||
uncommitted_rows = 0
|
||||
|
||||
# If we're updating User objects, update the email map
|
||||
if object_type == "User":
|
||||
if object_type == USER_OBJECT_TYPE:
|
||||
OnyxSalesforceSQLite._update_user_email_map(cursor)
|
||||
|
||||
return updated_ids
|
||||
@@ -619,7 +623,7 @@ class OnyxSalesforceSQLite:
|
||||
with self._conn:
|
||||
cursor = self._conn.cursor()
|
||||
# Get the object data and account data
|
||||
if object_type == "Account" or isChild:
|
||||
if object_type == ACCOUNT_OBJECT_TYPE or isChild:
|
||||
cursor.execute(
|
||||
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
|
||||
)
|
||||
@@ -638,7 +642,7 @@ class OnyxSalesforceSQLite:
|
||||
|
||||
data = json.loads(result[0][0])
|
||||
|
||||
if object_type != "Account":
|
||||
if object_type != ACCOUNT_OBJECT_TYPE:
|
||||
|
||||
# convert any account ids of the relationships back into data fields, with name
|
||||
for row in result:
|
||||
@@ -647,14 +651,14 @@ class OnyxSalesforceSQLite:
|
||||
if len(row) < 3:
|
||||
continue
|
||||
|
||||
if row[1] and row[2] and row[2] == "Account":
|
||||
if row[1] and row[2] and row[2] == ACCOUNT_OBJECT_TYPE:
|
||||
data["AccountId"] = row[1]
|
||||
cursor.execute(
|
||||
"SELECT data FROM salesforce_objects WHERE id = ?",
|
||||
(row[1],),
|
||||
)
|
||||
account_data = json.loads(cursor.fetchone()[0])
|
||||
data["Account"] = account_data.get("Name", "")
|
||||
data[ACCOUNT_OBJECT_TYPE] = account_data.get(NAME_FIELD, "")
|
||||
|
||||
return SalesforceObject(id=object_id, type=object_type, data=data)
|
||||
|
||||
|
||||
@@ -2,6 +2,11 @@ import os
from dataclasses import dataclass
from typing import Any

NAME_FIELD = "Name"
MODIFIED_FIELD = "LastModifiedDate"
ACCOUNT_OBJECT_TYPE = "Account"
USER_OBJECT_TYPE = "User"


@dataclass
class SalesforceObject:

@@ -1,13 +1,17 @@
|
||||
import html
|
||||
import io
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from urllib.parse import unquote
|
||||
|
||||
import msal # type: ignore
|
||||
import requests
|
||||
from office365.graph_client import GraphClient # type: ignore
|
||||
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
|
||||
from office365.onedrive.sites.site import Site # type: ignore
|
||||
@@ -33,6 +37,10 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
ASPX_EXTENSION = ".aspx"
|
||||
REQUEST_TIMEOUT = 10
|
||||
|
||||
|
||||
class SiteDescriptor(BaseModel):
|
||||
"""Data class for storing SharePoint site information.
|
||||
|
||||
@@ -136,11 +144,156 @@ def _convert_driveitem_to_document(
|
||||
return doc
|
||||
|
||||
|
||||
def _convert_sitepage_to_document(
|
||||
site_page: dict[str, Any], site_name: str | None
|
||||
) -> Document:
|
||||
"""Convert a SharePoint site page to a Document object."""
|
||||
# Extract text content from the site page
|
||||
page_text = ""
|
||||
|
||||
# Get title and description
|
||||
title = cast(str, site_page.get("title", ""))
|
||||
description = cast(str, site_page.get("description", ""))
|
||||
|
||||
# Build the text content
|
||||
if title:
|
||||
page_text += f"# {title}\n\n"
|
||||
if description:
|
||||
page_text += f"{description}\n\n"
|
||||
|
||||
# Extract content from canvas layout if available
|
||||
canvas_layout = site_page.get("canvasLayout", {})
|
||||
if canvas_layout:
|
||||
horizontal_sections = canvas_layout.get("horizontalSections", [])
|
||||
for section in horizontal_sections:
|
||||
columns = section.get("columns", [])
|
||||
for column in columns:
|
||||
webparts = column.get("webparts", [])
|
||||
for webpart in webparts:
|
||||
# Extract text from different types of webparts
|
||||
webpart_type = webpart.get("@odata.type", "")
|
||||
|
||||
# Extract text from text webparts
|
||||
if webpart_type == "#microsoft.graph.textWebPart":
|
||||
inner_html = webpart.get("innerHtml", "")
|
||||
if inner_html:
|
||||
# Basic HTML to text conversion
|
||||
# Remove HTML tags but preserve some structure
|
||||
text_content = re.sub(r"<br\s*/?>", "\n", inner_html)
|
||||
text_content = re.sub(r"<li>", "• ", text_content)
|
||||
text_content = re.sub(r"</li>", "\n", text_content)
|
||||
text_content = re.sub(
|
||||
r"<h[1-6][^>]*>", "\n## ", text_content
|
||||
)
|
||||
text_content = re.sub(r"</h[1-6]>", "\n", text_content)
|
||||
text_content = re.sub(r"<p[^>]*>", "\n", text_content)
|
||||
text_content = re.sub(r"</p>", "\n", text_content)
|
||||
text_content = re.sub(r"<[^>]+>", "", text_content)
|
||||
# Decode HTML entities
|
||||
text_content = html.unescape(text_content)
|
||||
# Clean up extra whitespace
|
||||
text_content = re.sub(
|
||||
r"\n\s*\n", "\n\n", text_content
|
||||
).strip()
|
||||
if text_content:
|
||||
page_text += f"{text_content}\n\n"
|
||||
|
||||
# Extract text from standard webparts
|
||||
elif webpart_type == "#microsoft.graph.standardWebPart":
|
||||
data = webpart.get("data", {})
|
||||
|
||||
# Extract from serverProcessedContent
|
||||
server_content = data.get("serverProcessedContent", {})
|
||||
searchable_texts = server_content.get(
|
||||
"searchablePlainTexts", []
|
||||
)
|
||||
|
||||
for text_item in searchable_texts:
|
||||
if isinstance(text_item, dict):
|
||||
key = text_item.get("key", "")
|
||||
value = text_item.get("value", "")
|
||||
if value:
|
||||
# Add context based on key
|
||||
if key == "title":
|
||||
page_text += f"## {value}\n\n"
|
||||
else:
|
||||
page_text += f"{value}\n\n"
|
||||
|
||||
# Extract description if available
|
||||
description = data.get("description", "")
|
||||
if description:
|
||||
page_text += f"{description}\n\n"
|
||||
|
||||
# Extract title if available
|
||||
webpart_title = data.get("title", "")
|
||||
if webpart_title and webpart_title != description:
|
||||
page_text += f"## {webpart_title}\n\n"
|
||||
|
||||
page_text = page_text.strip()
|
||||
|
||||
# If no content extracted, use the title as fallback
|
||||
if not page_text and title:
|
||||
page_text = title
|
||||
|
||||
# Parse creation and modification info
|
||||
created_datetime = site_page.get("createdDateTime")
|
||||
if created_datetime:
|
||||
if isinstance(created_datetime, str):
|
||||
created_datetime = datetime.fromisoformat(
|
||||
created_datetime.replace("Z", "+00:00")
|
||||
)
|
||||
elif not created_datetime.tzinfo:
|
||||
created_datetime = created_datetime.replace(tzinfo=timezone.utc)
|
||||
|
||||
last_modified_datetime = site_page.get("lastModifiedDateTime")
|
||||
if last_modified_datetime:
|
||||
if isinstance(last_modified_datetime, str):
|
||||
last_modified_datetime = datetime.fromisoformat(
|
||||
last_modified_datetime.replace("Z", "+00:00")
|
||||
)
|
||||
elif not last_modified_datetime.tzinfo:
|
||||
last_modified_datetime = last_modified_datetime.replace(tzinfo=timezone.utc)
|
||||
|
||||
# Extract owner information
|
||||
primary_owners = []
|
||||
created_by = site_page.get("createdBy", {}).get("user", {})
|
||||
if created_by.get("displayName"):
|
||||
primary_owners.append(
|
||||
BasicExpertInfo(
|
||||
display_name=created_by.get("displayName"),
|
||||
email=created_by.get("email", ""),
|
||||
)
|
||||
)
|
||||
|
||||
web_url = site_page["webUrl"]
|
||||
semantic_identifier = cast(str, site_page.get("name", title))
|
||||
if semantic_identifier.endswith(ASPX_EXTENSION):
|
||||
semantic_identifier = semantic_identifier[: -len(ASPX_EXTENSION)]
|
||||
|
||||
doc = Document(
|
||||
id=site_page["id"],
|
||||
sections=[TextSection(link=web_url, text=page_text)],
|
||||
source=DocumentSource.SHAREPOINT,
|
||||
semantic_identifier=semantic_identifier,
|
||||
doc_updated_at=last_modified_datetime or created_datetime,
|
||||
primary_owners=primary_owners,
|
||||
metadata=(
|
||||
{
|
||||
"site": site_name,
|
||||
}
|
||||
if site_name
|
||||
else {}
|
||||
),
|
||||
)
|
||||
return doc
|
||||
|
||||
|
||||
class SharepointConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
sites: list[str] = [],
|
||||
include_site_pages: bool = True,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self._graph_client: GraphClient | None = None
|
||||
@@ -148,6 +301,7 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
sites
|
||||
)
|
||||
self.msal_app: msal.ConfidentialClientApplication | None = None
|
||||
self.include_site_pages = include_site_pages
|
||||
|
||||
@property
|
||||
def graph_client(self) -> GraphClient:
|
||||
@@ -284,7 +438,7 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
|
||||
except Exception as e:
|
||||
# Some drives might not be accessible
|
||||
logger.warning(f"Failed to process drive: {str(e)}")
|
||||
logger.warning(f"Failed to process drive '{drive.name}': {str(e)}")
|
||||
|
||||
except Exception as e:
|
||||
err_str = str(e)
|
||||
@@ -327,6 +481,74 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
]
|
||||
return site_descriptors
|
||||
|
||||
def _fetch_site_pages(
|
||||
self,
|
||||
site_descriptor: SiteDescriptor,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Fetch SharePoint site pages (.aspx files) using the SharePoint Pages API."""
|
||||
# Get the site to extract the site ID
|
||||
site = self.graph_client.sites.get_by_url(site_descriptor.url)
|
||||
site.execute_query() # Execute the query to actually fetch the data
|
||||
site_id = site.id
|
||||
|
||||
# Get the token acquisition function from the GraphClient
|
||||
token_data = self._acquire_token()
|
||||
access_token = token_data.get("access_token")
|
||||
if not access_token:
|
||||
raise RuntimeError("Failed to acquire access token")
|
||||
|
||||
# Construct the SharePoint Pages API endpoint
|
||||
# Using API directly, since the Graph Client doesn't support the Pages API
|
||||
pages_endpoint = f"https://graph.microsoft.com/v1.0/sites/{site_id}/pages/microsoft.graph.sitePage"
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# Add expand parameter to get canvas layout content
|
||||
params = {"$expand": "canvasLayout"}
|
||||
|
||||
response = requests.get(
|
||||
pages_endpoint, headers=headers, params=params, timeout=REQUEST_TIMEOUT
|
||||
)
|
||||
response.raise_for_status()
|
||||
pages_data = response.json()
|
||||
all_pages = pages_data.get("value", [])
|
||||
|
||||
# Handle pagination if there are more pages
|
||||
while "@odata.nextLink" in pages_data:
|
||||
next_url = pages_data["@odata.nextLink"]
|
||||
response = requests.get(next_url, headers=headers, timeout=REQUEST_TIMEOUT)
|
||||
response.raise_for_status()
|
||||
pages_data = response.json()
|
||||
all_pages.extend(pages_data.get("value", []))
|
||||
|
||||
logger.debug(f"Found {len(all_pages)} site pages in {site_descriptor.url}")
|
||||
|
||||
# Filter pages based on time window if specified
|
||||
if start is not None or end is not None:
|
||||
filtered_pages = []
|
||||
for page in all_pages:
|
||||
page_modified = page.get("lastModifiedDateTime")
|
||||
if page_modified:
|
||||
if isinstance(page_modified, str):
|
||||
page_modified = datetime.fromisoformat(
|
||||
page_modified.replace("Z", "+00:00")
|
||||
)
|
||||
|
||||
if start is not None and page_modified < start:
|
||||
continue
|
||||
if end is not None and page_modified > end:
|
||||
continue
|
||||
|
||||
filtered_pages.append(page)
|
||||
all_pages = filtered_pages
|
||||
|
||||
return all_pages
|
||||
|
||||
def _fetch_from_sharepoint(
|
||||
self, start: datetime | None = None, end: datetime | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -335,6 +557,7 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
# goes over all urls, converts them into Document objects and then yields them in batches
|
||||
doc_batch: list[Document] = []
|
||||
for site_descriptor in site_descriptors:
|
||||
# Fetch regular documents from document libraries
|
||||
driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
|
||||
for driveitem, drive_name in driveitems:
|
||||
logger.debug(f"Processing: {driveitem.web_url}")
|
||||
@@ -347,8 +570,47 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
# Fetch SharePoint site pages (.aspx files)
|
||||
# Only fetch site pages if a folder is not specified since this processing
|
||||
# happens at a site-wide level + specifying a folder implies that the
|
||||
# user probably isn't looking for site pages
|
||||
specified_path = (
|
||||
site_descriptor.folder_path is not None
|
||||
or site_descriptor.drive_name is not None
|
||||
)
|
||||
if self.include_site_pages and not specified_path:
|
||||
site_pages = self._fetch_site_pages(
|
||||
site_descriptor, start=start, end=end
|
||||
)
|
||||
for site_page in site_pages:
|
||||
logger.debug(
|
||||
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
|
||||
)
|
||||
doc_batch.append(
|
||||
_convert_sitepage_to_document(
|
||||
site_page, site_descriptor.drive_name
|
||||
)
|
||||
)
|
||||
|
||||
if len(doc_batch) >= self.batch_size:
|
||||
yield doc_batch
|
||||
doc_batch = []
|
||||
|
||||
yield doc_batch
|
||||
|
||||
def _acquire_token(self) -> dict[str, Any]:
|
||||
"""
|
||||
Acquire token via MSAL
|
||||
"""
|
||||
if self.msal_app is None:
|
||||
raise RuntimeError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
)
|
||||
return token
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
sp_client_id = credentials["sp_client_id"]
|
||||
sp_client_secret = credentials["sp_client_secret"]
|
||||
@@ -360,20 +622,7 @@ class SharepointConnector(LoadConnector, PollConnector):
|
||||
client_id=sp_client_id,
|
||||
client_credential=sp_client_secret,
|
||||
)
|
||||
|
||||
def _acquire_token_func() -> dict[str, Any]:
|
||||
"""
|
||||
Acquire token via MSAL
|
||||
"""
|
||||
if self.msal_app is None:
|
||||
raise RuntimeError("MSAL app is not initialized")
|
||||
|
||||
token = self.msal_app.acquire_token_for_client(
|
||||
scopes=["https://graph.microsoft.com/.default"]
|
||||
)
|
||||
return token
|
||||
|
||||
self._graph_client = GraphClient(_acquire_token_func)
|
||||
self._graph_client = GraphClient(self._acquire_token)
|
||||
return None
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
|
||||
@@ -48,7 +48,9 @@ from onyx.db.relationships import (
|
||||
delete_from_kg_relationships_extraction_staging__no_commit,
|
||||
)
|
||||
from onyx.db.tag import delete_document_tags_for_documents__no_commit
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import model_to_dict
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
|
||||
@@ -150,7 +152,7 @@ def get_documents_for_cc_pair(
|
||||
|
||||
|
||||
def get_document_ids_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
db_session: Session, connector_id: int, credential_id: int
|
||||
) -> list[str]:
|
||||
doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
@@ -161,6 +163,47 @@ def get_document_ids_for_connector_credential_pair(
|
||||
return list(db_session.execute(doc_ids_stmt).scalars().all())
|
||||
|
||||
|
||||
def get_documents_for_connector_credential_pair_limited_columns(
|
||||
db_session: Session,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
sort_order: SortOrder | None = None,
|
||||
) -> Sequence[DocumentRow]:
|
||||
|
||||
doc_ids_subquery = select(DocumentByConnectorCredentialPair.id).where(
|
||||
and_(
|
||||
DocumentByConnectorCredentialPair.connector_id == connector_id,
|
||||
DocumentByConnectorCredentialPair.credential_id == credential_id,
|
||||
)
|
||||
)
|
||||
doc_ids_subquery = doc_ids_subquery.join(
|
||||
DbDocument, DocumentByConnectorCredentialPair.id == DbDocument.id
|
||||
)
|
||||
|
||||
stmt = select(
|
||||
DbDocument.id, DbDocument.doc_metadata, DbDocument.external_user_group_ids
|
||||
)
|
||||
|
||||
stmt = stmt.where(DbDocument.id.in_(doc_ids_subquery))
|
||||
|
||||
if sort_order == SortOrder.ASC:
|
||||
stmt = stmt.order_by(DbDocument.last_modified.asc())
|
||||
elif sort_order == SortOrder.DESC:
|
||||
stmt = stmt.order_by(DbDocument.last_modified.desc())
|
||||
|
||||
rows = db_session.execute(stmt).mappings().all()
|
||||
|
||||
doc_rows: list[DocumentRow] = []
|
||||
for row in rows:
|
||||
doc_row = DocumentRow(
|
||||
id=row.id,
|
||||
doc_metadata=row.doc_metadata,
|
||||
external_user_group_ids=row.external_user_group_ids,
|
||||
)
|
||||
doc_rows.append(doc_row)
|
||||
return doc_rows
|
||||
|
||||
|
||||
def get_documents_for_connector_credential_pair(
|
||||
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
|
||||
) -> Sequence[DbDocument]:
|
||||
@@ -370,6 +413,7 @@ def upsert_documents(
|
||||
if doc.external_access
|
||||
else {}
|
||||
),
|
||||
doc_metadata=doc.doc_metadata,
|
||||
)
|
||||
)
|
||||
for doc in seen_documents.values()
|
||||
@@ -389,6 +433,7 @@ def upsert_documents(
|
||||
"external_user_emails": insert_stmt.excluded.external_user_emails,
|
||||
"external_user_group_ids": insert_stmt.excluded.external_user_group_ids,
|
||||
"is_public": insert_stmt.excluded.is_public,
|
||||
"doc_metadata": insert_stmt.excluded.doc_metadata,
|
||||
},
|
||||
)
|
||||
db_session.execute(on_conflict_stmt)
|
||||
@@ -1031,7 +1076,7 @@ def reset_all_document_kg_stages(db_session: Session) -> int:
|
||||
|
||||
# The hasattr check is needed for type checking, even though rowcount
|
||||
# is guaranteed to exist at runtime for UPDATE operations
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
|
||||
|
||||
|
||||
def update_document_kg_stages(
|
||||
@@ -1054,7 +1099,7 @@ def update_document_kg_stages(
|
||||
result = db_session.execute(stmt)
|
||||
# The hasattr check is needed for type checking, even though rowcount
|
||||
# is guaranteed to exist at runtime for UPDATE operations
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0
|
||||
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
|
||||
|
||||
|
||||
def get_skipped_kg_documents(db_session: Session) -> list[str]:
|
||||
|
||||
@@ -12,6 +12,7 @@ from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import onyx.db.document as dbdocument
|
||||
from onyx.db.entity_type import UNGROUNDED_SOURCE_NAME
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import KGEntity
|
||||
from onyx.db.models import KGEntityExtractionStaging
|
||||
@@ -328,7 +329,13 @@ def get_entity_stats_by_grounded_source_name(
|
||||
.group_by(KGEntityType.grounded_source_name)
|
||||
.all()
|
||||
)
|
||||
|
||||
# `row.grounded_source_name` is NULLABLE in the database schema.
|
||||
# Thus, for all "ungrounded" entity-types, we use a default name.
|
||||
return {
|
||||
row.grounded_source_name: (row.last_updated, row.entities_count)
|
||||
(row.grounded_source_name or UNGROUNDED_SOURCE_NAME): (
|
||||
row.last_updated,
|
||||
row.entities_count,
|
||||
)
|
||||
for row in results
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ from onyx.kg.models import KGAttributeEntityOption
|
||||
from onyx.server.kg.models import EntityType
|
||||
|
||||
|
||||
_UNGROUNDED_SOURCE_NAME = "Ungrounded"
|
||||
UNGROUNDED_SOURCE_NAME = "Ungrounded"
|
||||
|
||||
|
||||
def get_entity_types_with_grounded_source_name(
|
||||
@@ -87,7 +87,7 @@ def get_configured_entity_types(db_session: Session) -> dict[str, list[KGEntityT
|
||||
|
||||
et_map = defaultdict(list)
|
||||
for et in ets:
|
||||
key = et.grounded_source_name or _UNGROUNDED_SOURCE_NAME
|
||||
key = et.grounded_source_name or UNGROUNDED_SOURCE_NAME
|
||||
et_map[key].append(et)
|
||||
|
||||
return et_map
|
||||
|
||||
@@ -267,6 +267,7 @@ class IndexingCoordination:
|
||||
index_attempt_id: int,
|
||||
current_batches_completed: int,
|
||||
timeout_hours: int = INDEXING_PROGRESS_TIMEOUT_HOURS,
|
||||
force_update_progress: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Update progress tracking for stall detection.
|
||||
@@ -281,7 +282,8 @@ class IndexingCoordination:
|
||||
current_time = get_db_current_time(db_session)
|
||||
|
||||
# No progress - check if this is the first time tracking
|
||||
if attempt.last_progress_time is None:
|
||||
# or if the caller wants to simulate guaranteed progress
|
||||
if attempt.last_progress_time is None or force_update_progress:
|
||||
# First time tracking - initialize
|
||||
attempt.last_progress_time = current_time
|
||||
attempt.last_batches_completed_count = current_batches_completed
|
||||
|
||||
@@ -608,6 +608,10 @@ class Document(Base):
|
||||
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
|
||||
"DocumentRetrievalFeedback", back_populates="document"
|
||||
)
|
||||
|
||||
doc_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
tags = relationship(
|
||||
"Tag",
|
||||
secondary=Document__Tag.__table__,
|
||||
|
||||
@@ -15,6 +15,7 @@ from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
|
||||
from onyx.configs.app_configs import DISABLE_AUTH
|
||||
from onyx.configs.chat_configs import BING_API_KEY
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
@@ -96,6 +97,14 @@ def _add_user_filters(
|
||||
where_clause = Persona.is_public == True # noqa: E712
|
||||
return stmt.where(where_clause)
|
||||
|
||||
# If curator ownership restriction is enabled, curators can only access their own assistants
|
||||
if CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS and user.role in [
|
||||
UserRole.CURATOR,
|
||||
UserRole.GLOBAL_CURATOR,
|
||||
]:
|
||||
where_clause = (Persona.user_id == user.id) | (Persona.user_id.is_(None))
|
||||
return stmt.where(where_clause)
|
||||
|
||||
where_clause = User__UserGroup.user_id == user.id
|
||||
if user.role == UserRole.CURATOR and get_editable:
|
||||
where_clause &= User__UserGroup.is_curator == True # noqa: E712
|
||||
|
||||
@@ -115,6 +115,7 @@ def create_file_connector_credential(
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": [user_file.file_id],
|
||||
"file_names": [user_file.name],
|
||||
"zip_metadata": {},
|
||||
},
|
||||
refresh_freq=None,
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from psycopg2 import errorcodes
|
||||
from psycopg2 import OperationalError
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import inspect
|
||||
|
||||
from onyx.db.models import Base
|
||||
@@ -27,3 +29,14 @@ def is_retryable_sqlalchemy_error(exc: BaseException) -> bool:
|
||||
pgcode = getattr(getattr(exc, "orig", None), "pgcode", None)
|
||||
return pgcode in RETRYABLE_PG_CODES
|
||||
return False
|
||||
|
||||
|
||||
class DocumentRow(BaseModel):
|
||||
id: str
|
||||
doc_metadata: dict[str, Any]
|
||||
external_user_group_ids: list[str]
|
||||
|
||||
|
||||
class SortOrder(str, Enum):
|
||||
ASC = "asc"
|
||||
DESC = "desc"
|
||||
|
||||
@@ -91,6 +91,7 @@ class DocumentMetadata:
|
||||
from_ingestion_api: bool = False
|
||||
|
||||
external_access: ExternalAccess | None = None
|
||||
doc_metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -17,11 +17,11 @@ from typing import NamedTuple
|
||||
from zipfile import BadZipFile
|
||||
|
||||
import chardet
|
||||
import docx # type: ignore
|
||||
import openpyxl # type: ignore
|
||||
import pptx # type: ignore
|
||||
from docx import Document as DocxDocument
|
||||
from fastapi import UploadFile
|
||||
from markitdown import FileConversionException
|
||||
from markitdown import MarkItDown
|
||||
from markitdown import UnsupportedFormatException
|
||||
from PIL import Image
|
||||
from pypdf import PdfReader
|
||||
from pypdf.errors import PdfStreamError
|
||||
@@ -83,11 +83,6 @@ IMAGE_MEDIA_TYPES = [
|
||||
"image/webp",
|
||||
]
|
||||
|
||||
KNOWN_OPENPYXL_BUGS = [
|
||||
"Value must be either numerical or a string containing a wildcard",
|
||||
"File contains no valid workbook part",
|
||||
]
|
||||
|
||||
|
||||
class OnyxExtensionType(IntFlag):
|
||||
Plain = auto()
|
||||
@@ -149,6 +144,13 @@ def is_macos_resource_fork_file(file_name: str) -> bool:
|
||||
)
|
||||
|
||||
|
||||
def to_bytesio(stream: IO[bytes]) -> BytesIO:
|
||||
if isinstance(stream, BytesIO):
|
||||
return stream
|
||||
data = stream.read() # consumes the stream!
|
||||
return BytesIO(data)
|
||||
|
||||
|
||||
def load_files_from_zip(
|
||||
zip_file_io: IO,
|
||||
ignore_macos_resource_fork_files: bool = True,
|
||||
@@ -305,19 +307,38 @@ def read_pdf_file(
|
||||
return "", metadata, []
|
||||
|
||||
|
||||
def extract_docx_images(docx_bytes: IO[Any]) -> list[tuple[bytes, str]]:
|
||||
"""
|
||||
Given the bytes of a docx file, extract all the images.
|
||||
Returns a list of tuples (image_bytes, image_name).
|
||||
"""
|
||||
out = []
|
||||
try:
|
||||
with zipfile.ZipFile(docx_bytes) as z:
|
||||
for name in z.namelist():
|
||||
if name.startswith("word/media/"):
|
||||
out.append((z.read(name), name.split("/")[-1]))
|
||||
except Exception:
|
||||
logger.exception("Failed to extract all docx images")
|
||||
return out
|
||||
|
||||
|
||||
def docx_to_text_and_images(
|
||||
file: IO[Any], file_name: str = ""
|
||||
) -> tuple[str, Sequence[tuple[bytes, str]]]:
|
||||
"""
|
||||
Extract text from a docx. If embed_images=True, also extract inline images.
|
||||
Extract text from a docx.
|
||||
Return (text_content, list_of_images).
|
||||
"""
|
||||
paragraphs = []
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
try:
|
||||
doc = docx.Document(file)
|
||||
except BadZipFile as e:
|
||||
doc = md.convert(to_bytesio(file))
|
||||
except (
|
||||
BadZipFile,
|
||||
ValueError,
|
||||
FileConversionException,
|
||||
UnsupportedFormatException,
|
||||
) as e:
|
||||
logger.warning(
|
||||
f"Failed to extract docx {file_name or 'docx file'}: {e}. Attempting to read as text file."
|
||||
)
|
||||
@@ -330,86 +351,44 @@ def docx_to_text_and_images(
|
||||
)
|
||||
return text_content_raw or "", []
|
||||
|
||||
# Grab text from paragraphs
|
||||
for paragraph in doc.paragraphs:
|
||||
paragraphs.append(paragraph.text)
|
||||
|
||||
# Reset position so we can re-load the doc (python-docx has read the stream)
|
||||
# Note: if python-docx has fully consumed the stream, you may need to open it again from memory.
|
||||
# For large docs, a more robust approach is needed.
|
||||
# This is a simplified example.
|
||||
|
||||
for rel_id, rel in doc.part.rels.items():
|
||||
if "image" in rel.reltype:
|
||||
# image is typically in rel.target_part.blob
|
||||
image_bytes = rel.target_part.blob
|
||||
image_name = rel.target_part.partname
|
||||
# store
|
||||
embedded_images.append((image_bytes, os.path.basename(str(image_name))))
|
||||
|
||||
text_content = "\n".join(paragraphs)
|
||||
return text_content, embedded_images
|
||||
file.seek(0)
|
||||
return doc.markdown, extract_docx_images(to_bytesio(file))
|
||||
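A minimal usage sketch of the MarkItDown-backed extractor above; the file path is a made-up example and the print is only there to show the returned tuple shape.

# Hypothetical usage sketch of docx_to_text_and_images
with open("example.docx", "rb") as f:  # "example.docx" is an assumed path
    text, images = docx_to_text_and_images(f, "example.docx")
    print(len(text), len(images))  # extracted markdown text and embedded image bytes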
|
||||
|
||||
def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
try:
|
||||
presentation = pptx.Presentation(file)
|
||||
except BadZipFile as e:
|
||||
presentation = md.convert(to_bytesio(file))
|
||||
except (
|
||||
BadZipFile,
|
||||
ValueError,
|
||||
FileConversionException,
|
||||
UnsupportedFormatException,
|
||||
) as e:
|
||||
error_str = f"Failed to extract text from {file_name or 'pptx file'}: {e}"
|
||||
logger.warning(error_str)
|
||||
return ""
|
||||
text_content = []
|
||||
for slide_number, slide in enumerate(presentation.slides, start=1):
|
||||
slide_text = f"\nSlide {slide_number}:\n"
|
||||
for shape in slide.shapes:
|
||||
if hasattr(shape, "text"):
|
||||
slide_text += shape.text + "\n"
|
||||
text_content.append(slide_text)
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
return presentation.markdown
|
||||
|
||||
|
||||
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
md = MarkItDown(enable_plugins=False)
|
||||
try:
|
||||
workbook = openpyxl.load_workbook(file, read_only=True)
|
||||
except BadZipFile as e:
|
||||
workbook = md.convert(to_bytesio(file))
|
||||
except (
|
||||
BadZipFile,
|
||||
ValueError,
|
||||
FileConversionException,
|
||||
UnsupportedFormatException,
|
||||
) as e:
|
||||
error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
|
||||
if file_name.startswith("~"):
|
||||
logger.debug(error_str + " (this is expected for files with ~)")
|
||||
else:
|
||||
logger.warning(error_str)
|
||||
return ""
|
||||
except Exception as e:
|
||||
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
|
||||
logger.error(
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
raise e
|
||||
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
rows = []
|
||||
num_empty_consecutive_rows = 0
|
||||
for row in sheet.iter_rows(min_row=1, values_only=True):
|
||||
row_str = ",".join(str(cell or "") for cell in row)
|
||||
|
||||
# Only add the row if there are any values in the cells
|
||||
if len(row_str) >= len(row):
|
||||
rows.append(row_str)
|
||||
num_empty_consecutive_rows = 0
|
||||
else:
|
||||
num_empty_consecutive_rows += 1
|
||||
|
||||
if num_empty_consecutive_rows > 100:
|
||||
# handle massive excel sheets with mostly empty cells
|
||||
logger.warning(
|
||||
f"Found {num_empty_consecutive_rows} empty rows in {file_name},"
|
||||
" skipping rest of file"
|
||||
)
|
||||
break
|
||||
sheet_str = "\n".join(rows)
|
||||
text_content.append(sheet_str)
|
||||
return TEXT_SECTION_SEPARATOR.join(text_content)
|
||||
return workbook.markdown
|
||||
|
||||
|
||||
def eml_to_text(file: IO[Any]) -> str:
|
||||
@@ -462,9 +441,9 @@ def extract_file_text(
|
||||
"""
|
||||
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
|
||||
".pdf": pdf_to_text,
|
||||
".docx": lambda f: docx_to_text_and_images(f)[0], # no images
|
||||
".pptx": pptx_to_text,
|
||||
".xlsx": xlsx_to_text,
|
||||
".docx": lambda f: docx_to_text_and_images(f, file_name)[0], # no images
|
||||
".pptx": lambda f: pptx_to_text(f, file_name),
|
||||
".xlsx": lambda f: xlsx_to_text(f, file_name),
|
||||
".eml": eml_to_text,
|
||||
".epub": epub_to_text,
|
||||
".html": parse_html_page_basic,
|
||||
@@ -522,23 +501,28 @@ def extract_text_and_images(
|
||||
Primary new function for the updated connector.
|
||||
Returns structured extraction result with text content, embedded images, and metadata.
|
||||
"""
|
||||
file.seek(0)
|
||||
|
||||
try:
|
||||
# Attempt unstructured if env var is set
|
||||
if get_unstructured_api_key():
|
||||
# If the user doesn't want embedded images, unstructured is fine
|
||||
file.seek(0)
|
||||
if get_unstructured_api_key():
|
||||
try:
|
||||
text_content = unstructured_to_text(file, file_name)
|
||||
return ExtractionResult(
|
||||
text_content=text_content, embedded_images=[], metadata={}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to process with Unstructured: {str(e)}. "
|
||||
"Falling back to normal processing."
|
||||
)
|
||||
file.seek(0) # Reset file pointer just in case
|
||||
|
||||
# Default processing
|
||||
try:
|
||||
extension = get_file_ext(file_name)
|
||||
|
||||
# docx example for embedded images
|
||||
if extension == ".docx":
|
||||
file.seek(0)
|
||||
text_content, images = docx_to_text_and_images(file)
|
||||
text_content, images = docx_to_text_and_images(file, file_name)
|
||||
return ExtractionResult(
|
||||
text_content=text_content, embedded_images=images, metadata={}
|
||||
)
|
||||
@@ -546,7 +530,6 @@ def extract_text_and_images(
|
||||
# PDF example: we do not show complicated PDF image extraction here
|
||||
# so we simply extract text for now and skip images.
|
||||
if extension == ".pdf":
|
||||
file.seek(0)
|
||||
text_content, pdf_metadata, images = read_pdf_file(
|
||||
file,
|
||||
pdf_pass,
|
||||
@@ -559,7 +542,6 @@ def extract_text_and_images(
|
||||
# For PPTX, XLSX, EML, etc., we do not show embedded image logic here.
|
||||
# You can do something similar to docx if needed.
|
||||
if extension == ".pptx":
|
||||
file.seek(0)
|
||||
return ExtractionResult(
|
||||
text_content=pptx_to_text(file, file_name=file_name),
|
||||
embedded_images=[],
|
||||
@@ -567,7 +549,6 @@ def extract_text_and_images(
|
||||
)
|
||||
|
||||
if extension == ".xlsx":
|
||||
file.seek(0)
|
||||
return ExtractionResult(
|
||||
text_content=xlsx_to_text(file, file_name=file_name),
|
||||
embedded_images=[],
|
||||
@@ -575,19 +556,16 @@ def extract_text_and_images(
|
||||
)
|
||||
|
||||
if extension == ".eml":
|
||||
file.seek(0)
|
||||
return ExtractionResult(
|
||||
text_content=eml_to_text(file), embedded_images=[], metadata={}
|
||||
)
|
||||
|
||||
if extension == ".epub":
|
||||
file.seek(0)
|
||||
return ExtractionResult(
|
||||
text_content=epub_to_text(file), embedded_images=[], metadata={}
|
||||
)
|
||||
|
||||
if extension == ".html":
|
||||
file.seek(0)
|
||||
return ExtractionResult(
|
||||
text_content=parse_html_page_basic(file),
|
||||
embedded_images=[],
|
||||
@@ -596,7 +574,6 @@ def extract_text_and_images(
|
||||
|
||||
# If we reach here and it's a recognized text extension
|
||||
if is_text_file_extension(file_name):
|
||||
file.seek(0)
|
||||
encoding = detect_encoding(file)
|
||||
text_content_raw, file_metadata = read_text_file(
|
||||
file, encoding=encoding, ignore_onyx_metadata=False
|
||||
@@ -609,7 +586,6 @@ def extract_text_and_images(
|
||||
|
||||
# If it's an image file or something else, we do not parse embedded images from them
|
||||
# just return empty text
|
||||
file.seek(0)
|
||||
return ExtractionResult(text_content="", embedded_images=[], metadata={})
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -32,9 +32,11 @@ def is_valid_image_type(mime_type: str) -> bool:
|
||||
Returns:
|
||||
True if the MIME type is a valid image type, False otherwise
|
||||
"""
|
||||
if not mime_type:
|
||||
return False
|
||||
return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
|
||||
return (
|
||||
bool(mime_type)
|
||||
and mime_type.startswith("image/")
|
||||
and mime_type not in EXCLUDED_IMAGE_TYPES
|
||||
)
|
||||
|
||||
|
||||
def is_supported_by_vision_llm(mime_type: str) -> bool:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import re
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from io import BytesIO
|
||||
from typing import IO
|
||||
|
||||
import bs4
|
||||
@@ -161,7 +162,7 @@ def format_document_soup(
|
||||
return strip_excessive_newlines_and_spaces(text)
|
||||
|
||||
|
||||
def parse_html_page_basic(text: str | IO[bytes]) -> str:
|
||||
def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
|
||||
soup = bs4.BeautifulSoup(text, "html.parser")
|
||||
return format_document_soup(soup)
|
||||
|
||||
|
||||
@@ -196,6 +196,9 @@ class FileStoreDocumentBatchStorage(DocumentBatchStorage):
|
||||
for batch_file_name in batch_names:
|
||||
path_info = self.extract_path_info(batch_file_name)
|
||||
if path_info is None:
|
||||
logger.warning(
|
||||
f"Could not extract path info from batch file: {batch_file_name}"
|
||||
)
|
||||
continue
|
||||
new_batch_file_name = self._get_batch_file_name(path_info.batch_num)
|
||||
self.file_store.change_file_id(batch_file_name, new_batch_file_name)
|
||||
|
||||
@@ -49,11 +49,10 @@ def sanitize_s3_key_name(file_name: str) -> str:
|
||||
|
||||
# Characters to avoid completely (replace with underscore)
|
||||
# These are characters that AWS recommends avoiding
|
||||
avoid_chars = r'[\\{}^%`\[\]"<>#|~]'
|
||||
avoid_chars = r'[\\{}^%`\[\]"<>#|~/]'
|
||||
|
||||
# Replace avoided characters with underscore
|
||||
sanitized = re.sub(avoid_chars, "_", file_name)
|
||||
|
||||
# Characters that might require special handling but are allowed
|
||||
# We'll URL encode these to be safe
|
||||
special_chars = r"[&$@=;:+,?\s]"
|
||||
@@ -81,6 +80,9 @@ def sanitize_s3_key_name(file_name: str) -> str:
|
||||
# Remove any trailing periods to avoid download issues
|
||||
sanitized = sanitized.rstrip(".")
|
||||
|
||||
# Remove multiple separators
|
||||
sanitized = re.sub(r"[-_]{2,}", "-", sanitized)
|
||||
|
||||
# If sanitization resulted in empty string, use a default
|
||||
if not sanitized:
|
||||
sanitized = "sanitized_file"
|
||||
|
||||
@@ -46,7 +46,6 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
|
||||
# Get plaintext file name
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
|
||||
|
||||
# Use a separate session to avoid committing the caller's transaction
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
file_content = BytesIO(plaintext_content.encode("utf-8"))
|
||||
|
||||
@@ -142,6 +142,7 @@ def _upsert_documents_in_db(
|
||||
secondary_owners=get_experts_stores_representations(doc.secondary_owners),
|
||||
from_ingestion_api=doc.from_ingestion_api,
|
||||
external_access=doc.external_access,
|
||||
doc_metadata=doc.doc_metadata,
|
||||
)
|
||||
document_metadata_list.append(db_doc_metadata)
|
||||
|
||||
@@ -866,31 +867,27 @@ def index_doc_batch(
|
||||
user_file_id_to_raw_text: dict[int, str] = {}
|
||||
for document_id in updatable_ids:
|
||||
# Only calculate token counts for documents that have a user file ID
|
||||
if (
|
||||
document_id in doc_id_to_user_file_id
|
||||
and doc_id_to_user_file_id[document_id] is not None
|
||||
):
|
||||
user_file_id = doc_id_to_user_file_id[document_id]
|
||||
if not user_file_id:
|
||||
continue
|
||||
document_chunks = [
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == document_id
|
||||
]
|
||||
if document_chunks:
|
||||
combined_content = " ".join(
|
||||
[chunk.content for chunk in document_chunks]
|
||||
)
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content))
|
||||
if llm_tokenizer
|
||||
else 0
|
||||
)
|
||||
user_file_id_to_token_count[user_file_id] = token_count
|
||||
user_file_id_to_raw_text[user_file_id] = combined_content
|
||||
else:
|
||||
user_file_id_to_token_count[user_file_id] = None
|
||||
|
||||
user_file_id = doc_id_to_user_file_id.get(document_id)
|
||||
if user_file_id is None:
|
||||
continue
|
||||
|
||||
document_chunks = [
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == document_id
|
||||
]
|
||||
if document_chunks:
|
||||
combined_content = " ".join(
|
||||
[chunk.content for chunk in document_chunks]
|
||||
)
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
|
||||
)
|
||||
user_file_id_to_token_count[user_file_id] = token_count
|
||||
user_file_id_to_raw_text[user_file_id] = combined_content
|
||||
else:
|
||||
user_file_id_to_token_count[user_file_id] = None
|
||||
|
||||
# we're concerned about race conditions where multiple simultaneous indexings might result
|
||||
# in one set of metadata overwriting another one in vespa.
|
||||
|
||||
@@ -313,14 +313,14 @@ class DefaultMultiLLM(LLM):
|
||||
|
||||
self._model_kwargs = model_kwargs
|
||||
|
||||
def log_model_configs(self) -> None:
|
||||
logger.debug(f"Config: {self.config}")
|
||||
|
||||
def _safe_model_config(self) -> dict:
|
||||
dump = self.config.model_dump()
|
||||
dump["api_key"] = mask_string(dump.get("api_key", ""))
|
||||
return dump
|
||||
|
||||
def log_model_configs(self) -> None:
|
||||
logger.debug(f"Config: {self._safe_model_config()}")
|
||||
|
||||
def _record_call(self, prompt: LanguageModelInput) -> None:
|
||||
if self._long_term_logger:
|
||||
self._long_term_logger.record(
|
||||
@@ -397,7 +397,11 @@ class DefaultMultiLLM(LLM):
|
||||
# streaming choice
|
||||
stream=stream,
|
||||
# model params
|
||||
temperature=self._temperature,
|
||||
temperature=(
|
||||
1
|
||||
if self.config.model_name in ["gpt-5", "gpt-5-mini", "gpt-5-nano"]
|
||||
else self._temperature
|
||||
),
|
||||
timeout=timeout_override or self._timeout,
|
||||
# For now, we don't support parallel tool calls
|
||||
# NOTE: we can't pass this in if tools are not specified
|
||||
|
||||
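The change above pins the temperature for the gpt-5 family instead of passing the configured value through; presumably those models only accept the default temperature of 1. A small illustrative sketch of the selection logic (the model-name list comes from this diff, the helper itself is not part of the codebase):

GPT5_MODELS = {"gpt-5", "gpt-5-mini", "gpt-5-nano"}


def effective_temperature(model_name: str, configured_temperature: float) -> float:
    # gpt-5 family models are pinned to the default temperature of 1
    return 1 if model_name in GPT5_MODELS else configured_temperature


assert effective_temperature("gpt-5-mini", 0.2) == 1
assert effective_temperature("gpt-4o", 0.2) == 0.2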
@@ -47,6 +47,9 @@ class WellKnownLLMProviderDescriptor(BaseModel):
|
||||
|
||||
OPENAI_PROVIDER_NAME = "openai"
|
||||
OPEN_AI_MODEL_NAMES = [
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"gpt-5-nano",
|
||||
"o4-mini",
|
||||
"o3-mini",
|
||||
"o1-mini",
|
||||
@@ -73,7 +76,14 @@ OPEN_AI_MODEL_NAMES = [
|
||||
"gpt-3.5-turbo-16k-0613",
|
||||
"gpt-3.5-turbo-0301",
|
||||
]
|
||||
OPEN_AI_VISIBLE_MODEL_NAMES = ["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"]
|
||||
OPEN_AI_VISIBLE_MODEL_NAMES = [
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"o1",
|
||||
"o3-mini",
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
]
|
||||
|
||||
BEDROCK_PROVIDER_NAME = "bedrock"
|
||||
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named
|
||||
|
||||
@@ -151,7 +151,7 @@ def _build_ephemeral_publication_block(
|
||||
email=message_info.email,
|
||||
sender_id=message_info.sender_id,
|
||||
thread_messages=[],
|
||||
is_bot_msg=message_info.is_bot_msg,
|
||||
is_slash_command=message_info.is_slash_command,
|
||||
is_bot_dm=message_info.is_bot_dm,
|
||||
thread_to_respond=respond_ts,
|
||||
)
|
||||
@@ -225,10 +225,10 @@ def _build_doc_feedback_block(
|
||||
|
||||
def get_restate_blocks(
|
||||
msg: str,
|
||||
is_bot_msg: bool,
|
||||
is_slash_command: bool,
|
||||
) -> list[Block]:
|
||||
# Only the slash command needs this context because the user doesn't see their own input
|
||||
if not is_bot_msg:
|
||||
if not is_slash_command:
|
||||
return []
|
||||
|
||||
return [
|
||||
@@ -576,7 +576,7 @@ def build_slack_response_blocks(
|
||||
# If called with the OnyxBot slash command, the question is lost so we have to reshow it
|
||||
if not skip_restated_question:
|
||||
restate_question_block = get_restate_blocks(
|
||||
message_info.thread_messages[-1].message, message_info.is_bot_msg
|
||||
message_info.thread_messages[-1].message, message_info.is_slash_command
|
||||
)
|
||||
else:
|
||||
restate_question_block = []
|
||||
|
||||
@@ -177,7 +177,7 @@ def handle_generate_answer_button(
|
||||
sender_id=user_id or None,
|
||||
email=email or None,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=False,
|
||||
is_slash_command=False,
|
||||
is_bot_dm=False,
|
||||
),
|
||||
slack_channel_config=slack_channel_config,
|
||||
|
||||
@@ -28,7 +28,7 @@ logger_base = setup_logger()
|
||||
|
||||
|
||||
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
|
||||
if details.is_bot_msg and details.sender_id:
|
||||
if details.is_slash_command and details.sender_id:
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
channel=details.channel_to_respond,
|
||||
@@ -124,11 +124,11 @@ def handle_message(
|
||||
messages = message_info.thread_messages
|
||||
sender_id = message_info.sender_id
|
||||
bypass_filters = message_info.bypass_filters
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
is_slash_command = message_info.is_slash_command
|
||||
is_bot_dm = message_info.is_bot_dm
|
||||
|
||||
action = "slack_message"
|
||||
if is_bot_msg:
|
||||
if is_slash_command:
|
||||
action = "slack_slash_message"
|
||||
elif bypass_filters:
|
||||
action = "slack_tag_message"
|
||||
@@ -197,7 +197,7 @@ def handle_message(
|
||||
|
||||
# If configured to respond to team members only, then cannot be used with a /OnyxBot command
|
||||
# which would just respond to the sender
|
||||
if send_to and is_bot_msg:
|
||||
if send_to and is_slash_command:
|
||||
if sender_id:
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
|
||||
@@ -81,15 +81,15 @@ def handle_regular_answer(
|
||||
messages = message_info.thread_messages
|
||||
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
is_slash_command = message_info.is_slash_command
|
||||
|
||||
# Capture whether the response mode for the channel is ephemeral. Even if the channel is set
|
||||
# to respond with an ephemeral message, we still send as non-ephemeral if
|
||||
# the message is a dm with the Onyx bot.
|
||||
send_as_ephemeral = (
|
||||
slack_channel_config.channel_config.get("is_ephemeral", False)
|
||||
and not message_info.is_bot_dm
|
||||
)
|
||||
or message_info.is_slash_command
|
||||
) and not message_info.is_bot_dm
|
||||
|
||||
# If the channel is configured to respond with an ephemeral message,
|
||||
# or the message is a dm to the Onyx bot, we should use the proper onyx user from the email.
|
||||
@@ -164,7 +164,7 @@ def handle_regular_answer(
|
||||
# in an attached document set were available to all users in the channel.)
|
||||
bypass_acl = False
|
||||
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
if not message_ts_to_respond_to and not is_slash_command:
|
||||
# if the message is not "/onyx" command, then it should have a message ts to respond to
|
||||
raise RuntimeError(
|
||||
"No message timestamp to respond to in `handle_message`. This should never happen."
|
||||
@@ -316,13 +316,14 @@ def handle_regular_answer(
|
||||
return True
|
||||
|
||||
# Got an answer at this point, can remove reaction and give results
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
if not is_slash_command: # Slash commands don't have reactions
|
||||
update_emote_react(
|
||||
emoji=DANSWER_REACT_EMOJI,
|
||||
channel=message_info.channel_to_respond,
|
||||
message_ts=message_info.msg_to_respond,
|
||||
remove=True,
|
||||
client=client,
|
||||
)
|
||||
|
||||
if answer.answer_valid is False:
|
||||
logger.notice(
|
||||
|
||||
@@ -876,12 +876,13 @@ def build_request_details(
|
||||
sender_id=sender_id,
|
||||
email=email,
|
||||
bypass_filters=tagged,
|
||||
is_bot_msg=False,
|
||||
is_slash_command=False,
|
||||
is_bot_dm=event.get("channel_type") == "im",
|
||||
)
|
||||
|
||||
elif req.type == "slash_commands":
|
||||
channel = req.payload["channel_id"]
|
||||
channel_name = req.payload["channel_name"]
|
||||
msg = req.payload["text"]
|
||||
sender = req.payload["user_id"]
|
||||
expert_info = expert_info_from_slack_id(
|
||||
@@ -899,8 +900,8 @@ def build_request_details(
|
||||
sender_id=sender,
|
||||
email=email,
|
||||
bypass_filters=True,
|
||||
is_bot_msg=True,
|
||||
is_bot_dm=False,
|
||||
is_slash_command=True,
|
||||
is_bot_dm=channel_name == "directmessage",
|
||||
)
|
||||
|
||||
raise RuntimeError("Programming fault, this should never happen.")
|
||||
|
||||
@@ -13,7 +13,7 @@ class SlackMessageInfo(BaseModel):
|
||||
sender_id: str | None
|
||||
email: str | None
|
||||
bypass_filters: bool # User has tagged @OnyxBot
|
||||
is_bot_msg: bool # User is using /OnyxBot
|
||||
is_slash_command: bool # User is using /OnyxBot
|
||||
is_bot_dm: bool # User is direct messaging to OnyxBot
|
||||
|
||||
|
||||
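The field rename from is_bot_msg to is_slash_command makes the flag's meaning explicit: it marks messages that arrive via the /OnyxBot slash command, not messages authored by the bot. A hedged construction example with placeholder values (a trimmed-down model for illustration, not the full SlackMessageInfo):

from pydantic import BaseModel


class SlashCommandFlagsSketch(BaseModel):
    sender_id: str | None
    is_slash_command: bool  # user invoked /OnyxBot
    is_bot_dm: bool  # user direct-messaged OnyxBot


info = SlashCommandFlagsSketch(sender_id="U123", is_slash_command=True, is_bot_dm=False)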
@@ -25,7 +25,7 @@ class ActionValuesEphemeralMessageMessageInfo(BaseModel):
|
||||
email: str | None
|
||||
sender_id: str | None
|
||||
thread_messages: list[ThreadMessage] | None
|
||||
is_bot_msg: bool | None
|
||||
is_slash_command: bool | None
|
||||
is_bot_dm: bool | None
|
||||
thread_to_respond: str | None
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import redis
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_connector_stop import RedisConnectorStop
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
@@ -31,11 +30,6 @@ class RedisConnector:
|
||||
tenant_id, cc_pair_id, self.redis
|
||||
)
|
||||
|
||||
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
|
||||
return RedisConnectorIndex(
|
||||
self.tenant_id, self.cc_pair_id, search_settings_id, self.redis
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_id_from_fence_key(key: str) -> str | None:
|
||||
"""
|
||||
@@ -81,3 +75,11 @@ class RedisConnector:
|
||||
|
||||
object_id = parts[1]
|
||||
return object_id
|
||||
|
||||
def db_lock_key(self, search_settings_id: int) -> str:
|
||||
"""
|
||||
Key for the db lock for an indexing attempt.
|
||||
Prevents multiple modifications to the current indexing attempt row
|
||||
from multiple docfetching/docprocessing tasks.
|
||||
"""
|
||||
return f"da_lock:indexing:db_{self.cc_pair_id}/{search_settings_id}"
|
||||
|
||||
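The new db_lock_key gives each cc_pair/search_settings combination its own Redis key so that docfetching and docprocessing workers can serialize writes to the indexing attempt row. A hedged sketch of how such a lock might be taken with redis-py (the timeout values are illustrative, not taken from the codebase):

import redis

r = redis.Redis()
cc_pair_id, search_settings_id = 42, 7
lock_key = f"da_lock:indexing:db_{cc_pair_id}/{search_settings_id}"

# redis-py's Lock provides mutual exclusion on the key while held
with r.lock(lock_key, timeout=60, blocking_timeout=10):
    pass  # update the indexing attempt row here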
@@ -1,126 +1,10 @@
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
import redis
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
|
||||
|
||||
class RedisConnectorIndexPayload(BaseModel):
|
||||
index_attempt_id: int | None
|
||||
started: datetime | None
|
||||
submitted: datetime
|
||||
celery_task_id: str | None
|
||||
|
||||
|
||||
class RedisConnectorIndex:
|
||||
"""Manages interactions with redis for indexing tasks. Should only be accessed
|
||||
through RedisConnector."""
|
||||
|
||||
PREFIX = "connectorindexing"
|
||||
FENCE_PREFIX = f"{PREFIX}_fence" # "connectorindexing_fence"
|
||||
GENERATOR_TASK_PREFIX = PREFIX + "+generator" # "connectorindexing+generator_fence"
|
||||
GENERATOR_PROGRESS_PREFIX = (
|
||||
PREFIX + "_generator_progress"
|
||||
) # connectorindexing_generator_progress
|
||||
GENERATOR_COMPLETE_PREFIX = (
|
||||
PREFIX + "_generator_complete"
|
||||
) # connectorindexing_generator_complete
|
||||
|
||||
GENERATOR_LOCK_PREFIX = "da_lock:indexing:docfetching"
|
||||
FILESTORE_LOCK_PREFIX = "da_lock:indexing:filestore"
|
||||
DB_LOCK_PREFIX = "da_lock:indexing:db"
|
||||
PER_WORKER_LOCK_PREFIX = "da_lock:indexing:per_worker"
|
||||
|
||||
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
|
||||
TERMINATE_TTL = 600
|
||||
|
||||
# used to signal the overall workflow is still active
|
||||
# it's impossible to get the exact state of the system at a single point in time
|
||||
# so we need a signal with a TTL to bridge gaps in our checks
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
# used to signal that the watchdog is running
|
||||
WATCHDOG_PREFIX = PREFIX + "_watchdog"
|
||||
WATCHDOG_TTL = 300
|
||||
|
||||
# used to signal that the connector itself is still running
|
||||
CONNECTOR_ACTIVE_PREFIX = PREFIX + "_connector_active"
|
||||
CONNECTOR_ACTIVE_TTL = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
redis: redis.Redis,
|
||||
) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.cc_pair_id = cc_pair_id
|
||||
self.search_settings_id = search_settings_id
|
||||
self.redis = redis
|
||||
|
||||
self.generator_complete_key = (
|
||||
f"{self.GENERATOR_COMPLETE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.filestore_lock_key = (
|
||||
f"{self.FILESTORE_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.generator_lock_key = (
|
||||
f"{self.GENERATOR_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.per_worker_lock_key = (
|
||||
f"{self.PER_WORKER_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
self.db_lock_key = f"{self.DB_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
self.terminate_key = (
|
||||
f"{self.TERMINATE_PREFIX}_{cc_pair_id}/{search_settings_id}"
|
||||
)
|
||||
|
||||
def set_generator_complete(self, payload: int | None) -> None:
|
||||
if not payload:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
return
|
||||
|
||||
self.redis.set(self.generator_complete_key, payload)
|
||||
|
||||
def generator_clear(self) -> None:
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
|
||||
def get_completion(self) -> int | None:
|
||||
bytes = self.redis.get(self.generator_complete_key)
|
||||
if bytes is None:
|
||||
return None
|
||||
|
||||
status = int(cast(int, bytes))
|
||||
return status
|
||||
|
||||
def reset(self) -> None:
|
||||
self.redis.delete(self.filestore_lock_key)
|
||||
self.redis.delete(self.db_lock_key)
|
||||
self.redis.delete(self.generator_lock_key)
|
||||
self.redis.delete(self.generator_complete_key)
|
||||
|
||||
@staticmethod
|
||||
def reset_all(r: redis.Redis) -> None:
|
||||
"""Deletes all redis values for all connectors"""
|
||||
# leaving these temporarily for backwards compat, TODO: remove
|
||||
for key in r.scan_iter(RedisConnectorIndex.CONNECTOR_ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.FILESTORE_LOCK_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_PROGRESS_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
for key in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
|
||||
r.delete(key)
|
||||
|
||||
@@ -28,10 +28,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
|
||||
@property
|
||||
def fenced(self) -> bool:
|
||||
if self.redis.exists(self.fence_key):
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(self.redis.exists(self.fence_key))
|
||||
|
||||
def set_fence(self, payload: int | None) -> None:
|
||||
if payload is None:
|
||||
|
||||
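Since redis.exists() returns the number of matching keys, wrapping it in bool() collapses the earlier if/else into one line without changing behavior. The two forms are equivalent, assuming a redis-py client r and a key name fence_key:

def fenced_old(r, fence_key) -> bool:
    if r.exists(fence_key):
        return True
    return False


def fenced_new(r, fence_key) -> bool:
    return bool(r.exists(fence_key))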
@@ -1,6 +1,5 @@
|
||||
from onyx.redis.redis_connector_delete import RedisConnectorDelete
|
||||
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_connector_prune import RedisConnectorPrune
|
||||
from onyx.redis.redis_document_set import RedisDocumentSet
|
||||
from onyx.redis.redis_usergroup import RedisUserGroup
|
||||
@@ -16,8 +15,6 @@ def is_fence(key_bytes: bytes) -> bool:
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
|
||||
return True
|
||||
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
|
||||
return True
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import io
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
@@ -101,8 +102,9 @@ from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.file_processing.extract_file_text import convert_docx_to_txt
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.documents.models import AuthStatus
|
||||
from onyx.server.documents.models import AuthUrl
|
||||
@@ -124,6 +126,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
|
||||
from onyx.server.documents.models import ObjectCreationIdResponse
|
||||
from onyx.server.documents.models import RunConnectorRequest
|
||||
from onyx.server.models import StatusResponse
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -418,7 +421,29 @@ def extract_zip_metadata(zf: zipfile.ZipFile) -> dict[str, Any]:
|
||||
return zip_metadata
|
||||
|
||||
|
||||
def upload_files(files: list[UploadFile]) -> FileUploadResponse:
|
||||
def is_zip_file(file: UploadFile) -> bool:
|
||||
"""
|
||||
Check if the file is a zip file by content type or filename.
|
||||
"""
|
||||
return bool(
|
||||
(
|
||||
file.content_type
|
||||
and file.content_type.startswith(
|
||||
(
|
||||
"application/zip",
|
||||
"application/x-zip-compressed", # May be this in Windows
|
||||
"application/x-zip",
|
||||
"multipart/x-zip",
|
||||
)
|
||||
)
|
||||
)
|
||||
or (file.filename and file.filename.lower().endswith(".zip"))
|
||||
)
|
||||
|
||||
|
||||
def upload_files(
|
||||
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
|
||||
) -> FileUploadResponse:
|
||||
for file in files:
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="File name cannot be empty")
|
||||
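A usage sketch for the is_zip_file helper above. Only the two attributes the helper reads are modeled here, via a duck-typed stand-in instead of a real FastAPI UploadFile, and it assumes the helper is in scope:

from dataclasses import dataclass


@dataclass
class FakeUpload:
    content_type: str | None
    filename: str | None


assert is_zip_file(FakeUpload("application/x-zip-compressed", "docs"))  # matched by MIME type
assert is_zip_file(FakeUpload("application/octet-stream", "docs.zip"))  # matched by extension
assert not is_zip_file(FakeUpload("text/plain", "notes.txt"))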
@@ -429,12 +454,13 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
|
||||
return not any(part.startswith(".") for part in normalized_path.split(os.sep))
|
||||
|
||||
deduped_file_paths = []
|
||||
deduped_file_names = []
|
||||
zip_metadata = {}
|
||||
try:
|
||||
file_store = get_default_file_store()
|
||||
seen_zip = False
|
||||
for file in files:
|
||||
if file.content_type and file.content_type.startswith("application/zip"):
|
||||
if is_zip_file(file):
|
||||
if seen_zip:
|
||||
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
|
||||
seen_zip = True
|
||||
@@ -460,14 +486,24 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
|
||||
file_type=mime_type,
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(os.path.basename(file_info))
|
||||
continue
|
||||
|
||||
# Special handling for docx files - only store the plaintext version
|
||||
if file.content_type and file.content_type.startswith(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
docx_file_id = convert_docx_to_txt(file, file_store)
|
||||
deduped_file_paths.append(docx_file_id)
|
||||
# For mypy, actual check happens at start of function
|
||||
assert file.filename is not None
|
||||
|
||||
# Special handling for doc files - only store the plaintext version
|
||||
file_type = mime_type_to_chat_file_type(file.content_type)
|
||||
if file_type == ChatFileType.DOC:
|
||||
extracted_text = extract_file_text(file.file, file.filename or "")
|
||||
text_file_id = file_store.save_file(
|
||||
content=io.BytesIO(extracted_text.encode()),
|
||||
display_name=file.filename,
|
||||
file_origin=file_origin,
|
||||
file_type="text/plain",
|
||||
)
|
||||
deduped_file_paths.append(text_file_id)
|
||||
deduped_file_names.append(file.filename)
|
||||
continue
|
||||
|
||||
# Default handling for all other file types
|
||||
@@ -478,10 +514,15 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
|
||||
file_type=file.content_type or "text/plain",
|
||||
)
|
||||
deduped_file_paths.append(file_id)
|
||||
deduped_file_names.append(file.filename)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return FileUploadResponse(file_paths=deduped_file_paths, zip_metadata=zip_metadata)
|
||||
return FileUploadResponse(
|
||||
file_paths=deduped_file_paths,
|
||||
file_names=deduped_file_names,
|
||||
zip_metadata=zip_metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/admin/connector/file/upload")
|
||||
@@ -489,7 +530,7 @@ def upload_files_api(
|
||||
files: list[UploadFile],
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
) -> FileUploadResponse:
|
||||
return upload_files(files)
|
||||
return upload_files(files, FileOrigin.OTHER)
|
||||
|
||||
|
||||
@router.get("/admin/connector")
|
||||
|
||||
@@ -475,6 +475,7 @@ class GoogleServiceAccountCredentialRequest(BaseModel):
|
||||
|
||||
class FileUploadResponse(BaseModel):
|
||||
file_paths: list[str]
|
||||
file_names: list[str]
|
||||
zip_metadata: dict[str, Any]
|
||||
|
||||
|
||||
|
||||
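FileUploadResponse now carries a file_names list parallel to file_paths, so the original display name survives conversions such as docx-to-text. An illustrative construction with placeholder values, assuming the model above is in scope:

resp = FileUploadResponse(
    file_paths=["a1b2c3-converted.txt"],  # stored file id/path
    file_names=["quarterly_report.docx"],  # original display name, index-aligned with file_paths
    zip_metadata={},
)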
@@ -179,15 +179,19 @@ def get_kg_entity_types(
|
||||
) -> SourceAndEntityTypeView:
|
||||
# when using for the first time, populate with default entity types
|
||||
entity_types = {
|
||||
key: [EntityType.from_model(et) for et in ets]
|
||||
for key, ets in get_configured_entity_types(db_session=db_session).items()
|
||||
source_name: [EntityType.from_model(et) for et in ets]
|
||||
for source_name, ets in get_configured_entity_types(
|
||||
db_session=db_session
|
||||
).items()
|
||||
}
|
||||
|
||||
source_statistics = {
|
||||
key: SourceStatistics(
|
||||
source_name=key, last_updated=last_updated, entities_count=entities_count
|
||||
source_name: SourceStatistics(
|
||||
source_name=source_name,
|
||||
last_updated=last_updated,
|
||||
entities_count=entities_count,
|
||||
)
|
||||
for key, (
|
||||
for source_name, (
|
||||
last_updated,
|
||||
entities_count,
|
||||
) in get_entity_stats_by_grounded_source_name(db_session=db_session).items()
|
||||
|
||||
@@ -206,5 +206,5 @@ def create_deletion_attempt_for_connector_id(
|
||||
if cc_pair.connector.source == DocumentSource.FILE:
|
||||
connector = cc_pair.connector
|
||||
file_store = get_default_file_store()
|
||||
for file_name in connector.connector_specific_config.get("file_locations", []):
|
||||
file_store.delete_file(file_name)
|
||||
for file_id in connector.connector_specific_config.get("file_locations", []):
|
||||
file_store.delete_file(file_id)
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
@@ -31,7 +30,6 @@ from onyx.chat.prompt_builder.citations_prompt import (
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
|
||||
@@ -63,9 +61,7 @@ from onyx.db.models import User
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.user_documents import create_user_files
|
||||
from onyx.file_processing.extract_file_text import docx_to_txt_filename
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.factory import get_default_llms
|
||||
@@ -717,105 +713,65 @@ def upload_files_for_chat(
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="File size must be less than 20MB",
|
||||
detail="Images must be less than 20MB",
|
||||
)
|
||||
|
||||
file_store = get_default_file_store()
|
||||
|
||||
file_info: list[tuple[str, str | None, ChatFileType]] = []
|
||||
for file in files:
|
||||
file_type = mime_type_to_chat_file_type(file.content_type)
|
||||
|
||||
file_content = file.file.read() # Read the file content
|
||||
|
||||
# NOTE: Image conversion to JPEG used to be enforced here.
|
||||
# This was removed to:
|
||||
# 1. Preserve original file content for downloads
|
||||
# 2. Maintain transparency in formats like PNG
|
||||
# 3. Avoid issues introduced by file conversion
|
||||
file_content_io = io.BytesIO(file_content)
|
||||
|
||||
new_content_type = file.content_type
|
||||
|
||||
# Store the file normally
|
||||
file_id = file_store.save_file(
|
||||
content=file_content_io,
|
||||
display_name=file.filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=new_content_type or file_type.value,
|
||||
# 5) Create a user file for each uploaded file
|
||||
user_files = create_user_files(files, RECENT_DOCS_FOLDER_ID, user, db_session)
|
||||
for user_file in user_files:
|
||||
# 6) Create connector
|
||||
connector_base = ConnectorBase(
|
||||
name=f"UserFile-{int(time.time())}",
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": [user_file.file_id],
|
||||
"file_names": [user_file.name],
|
||||
"zip_metadata": {},
|
||||
},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
)
|
||||
connector = create_connector(
|
||||
db_session=db_session,
|
||||
connector_data=connector_base,
|
||||
)
|
||||
|
||||
# 4) If the file is a doc, extract text and store that separately
|
||||
if file_type == ChatFileType.DOC:
|
||||
# Re-wrap bytes in a fresh BytesIO so we start at position 0
|
||||
extracted_text_io = io.BytesIO(file_content)
|
||||
extracted_text = extract_file_text(
|
||||
file=extracted_text_io, # use the bytes we already read
|
||||
file_name=file.filename or "",
|
||||
)
|
||||
# 7) Create credential
|
||||
credential_info = CredentialBase(
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=True,
|
||||
groups=[],
|
||||
name=f"UserFileCredential-{int(time.time())}",
|
||||
is_user_file=True,
|
||||
)
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
text_file_id = file_store.save_file(
|
||||
content=io.BytesIO(extracted_text.encode()),
|
||||
display_name=file.filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type="text/plain",
|
||||
)
|
||||
# Return the text file as the "main" file descriptor for doc types
|
||||
file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT))
|
||||
else:
|
||||
file_info.append((file_id, file.filename, file_type))
|
||||
|
||||
# 5) Create a user file for each uploaded file
|
||||
user_files = create_user_files([file], RECENT_DOCS_FOLDER_ID, user, db_session)
|
||||
for user_file in user_files:
|
||||
# 6) Create connector
|
||||
connector_base = ConnectorBase(
|
||||
name=f"UserFile-{int(time.time())}",
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": [user_file.file_id],
|
||||
"zip_metadata": {},
|
||||
},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
)
|
||||
connector = create_connector(
|
||||
db_session=db_session,
|
||||
connector_data=connector_base,
|
||||
)
|
||||
|
||||
# 7) Create credential
|
||||
credential_info = CredentialBase(
|
||||
credential_json={},
|
||||
admin_public=True,
|
||||
source=DocumentSource.FILE,
|
||||
curator_public=True,
|
||||
groups=[],
|
||||
name=f"UserFileCredential-{int(time.time())}",
|
||||
is_user_file=True,
|
||||
)
|
||||
credential = create_credential(credential_info, user, db_session)
|
||||
|
||||
# 8) Create connector credential pair
|
||||
cc_pair = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
|
||||
access_type=AccessType.PRIVATE,
|
||||
auto_sync_options=None,
|
||||
groups=[],
|
||||
)
|
||||
user_file.cc_pair_id = cc_pair.data
|
||||
db_session.commit()
|
||||
# 8) Create connector credential pair
|
||||
cc_pair = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
connector_id=connector.id,
|
||||
credential_id=credential.id,
|
||||
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
|
||||
access_type=AccessType.PRIVATE,
|
||||
auto_sync_options=None,
|
||||
groups=[],
|
||||
)
|
||||
user_file.cc_pair_id = cc_pair.data
|
||||
db_session.commit()
|
||||
|
||||
return {
|
||||
"files": [
|
||||
{"id": file_id, "type": file_type, "name": file_name}
|
||||
for file_id, file_name, file_type in file_info
|
||||
{
|
||||
"id": user_file.file_id,
|
||||
"type": mime_type_to_chat_file_type(user_file.content_type),
|
||||
"name": user_file.name,
|
||||
}
|
||||
for user_file in user_files
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
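After the refactor, each uploaded user file gets its own file connector whose connector_specific_config records the display name alongside the stored file id. A small, purely illustrative helper showing that config shape (note the new "file_names" entry):

def user_file_connector_config(file_id: str, display_name: str) -> dict:
    return {
        "file_locations": [file_id],
        "file_names": [display_name],
        "zip_metadata": {},
    }


assert user_file_connector_config("abc123", "notes.txt")["file_names"] == ["notes.txt"]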
@@ -408,6 +408,7 @@ def create_file_from_link(
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": [user_file.file_id],
|
||||
"file_names": [user_file.name],
|
||||
"zip_metadata": {},
|
||||
},
|
||||
refresh_freq=None,
|
||||
|
||||
@@ -44,12 +44,12 @@ litellm==1.72.2
|
||||
lxml==5.3.0
|
||||
lxml_html_clean==0.2.2
|
||||
Mako==1.2.4
|
||||
markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2
|
||||
msal==1.28.0
|
||||
nltk==3.9.1
|
||||
Office365-REST-Python-Client==2.5.9
|
||||
oauthlib==3.2.2
|
||||
openai==1.75.0
|
||||
openpyxl==3.0.10
|
||||
passlib==1.7.4
|
||||
playwright==1.41.2
|
||||
psutil==5.9.5
|
||||
@@ -66,7 +66,7 @@ pypdf==5.4.0
|
||||
pytest-mock==3.12.0
|
||||
pytest-playwright==0.7.0
|
||||
python-docx==1.1.2
|
||||
python-dotenv==1.0.0
|
||||
python-dotenv==1.1.1
|
||||
python-multipart==0.0.20
|
||||
pywikibot==9.0.0
|
||||
redis==5.0.8
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.configs.app_configs import REDIS_SSL
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_index import RedisConnectorIndex
|
||||
from onyx.redis.redis_pool import RedisPool
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -130,9 +129,6 @@ def onyx_redis(
|
||||
logger.info(f"Purging locks associated with deleting cc_pair={cc_pair_id}.")
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
|
||||
match_pattern = f"{tenant_id}:{RedisConnectorIndex.FENCE_PREFIX}_{cc_pair_id}/*"
|
||||
purge_by_match_and_type(match_pattern, "string", batch, dry_run, r)
|
||||
|
||||
redis_delete_if_exists_helper(
|
||||
f"{tenant_id}:{redis_connector.prune.fence_key}", dry_run, r
|
||||
)
|
||||
|
||||
@@ -187,7 +187,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
|
||||
f"{connector_id} and Credential ID: {credential_id} does not exist."
|
||||
)
|
||||
|
||||
file_names: list[str] = (
|
||||
file_ids: list[str] = (
|
||||
cc_pair.connector.connector_specific_config["file_locations"]
|
||||
if cc_pair.connector.source == DocumentSource.FILE
|
||||
else []
|
||||
@@ -211,12 +211,12 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete connector due to {e}")
|
||||
|
||||
if file_names:
|
||||
if file_ids:
|
||||
logger.notice("Deleting stored files!")
|
||||
file_store = get_default_file_store()
|
||||
for file_name in file_names:
|
||||
logger.notice(f"Deleting file {file_name}")
|
||||
file_store.delete_file(file_name)
|
||||
for file_id in file_ids:
|
||||
logger.notice(f"Deleting file {file_id}")
|
||||
file_store.delete_file(file_id)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.confluence.connector import ConfluenceConnector
|
||||
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
|
||||
|
||||
|
||||
@@ -101,7 +103,17 @@ def test_confluence_connector_restriction_handling(
|
||||
mock_cc_pair.credential_id = 1
|
||||
|
||||
# Call the confluence_doc_sync function directly with the mock cc_pair
|
||||
doc_access_generator = confluence_doc_sync(mock_cc_pair, lambda: [], None)
|
||||
def mock_fetch_all_docs_fn(
|
||||
sort_order: SortOrder | None = None,
|
||||
) -> list[DocumentRow]:
|
||||
return []
|
||||
|
||||
def mock_fetch_all_docs_ids_fn() -> list[str]:
|
||||
return []
|
||||
|
||||
doc_access_generator = confluence_doc_sync(
|
||||
mock_cc_pair, mock_fetch_all_docs_fn, mock_fetch_all_docs_ids_fn, None
|
||||
)
|
||||
doc_access_list = list(doc_access_generator)
|
||||
assert len(doc_access_list) == 7
|
||||
assert all(
|
||||
|
||||
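The permission-sync entry points now take two callbacks rather than a single lambda; the test stubs above spell out the expected signatures. Restated as standalone stubs (DocumentRow and SortOrder are the types imported in the hunk above; the bodies are no-ops exactly as in the test):

from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder


def fetch_all_docs(sort_order: SortOrder | None = None) -> list[DocumentRow]:
    return []


def fetch_all_doc_ids() -> list[str]:
    return []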
@@ -56,7 +56,9 @@ def test_single_text_file_with_metadata(
|
||||
"onyx.connectors.file.connector.get_default_file_store",
|
||||
return_value=mock_file_store,
|
||||
):
|
||||
connector = LocalFileConnector(file_locations=["test.txt"], zip_metadata={})
|
||||
connector = LocalFileConnector(
|
||||
file_locations=["test.txt"], file_names=["test.txt"], zip_metadata={}
|
||||
)
|
||||
batches = list(connector.load_from_state())
|
||||
|
||||
assert len(batches) == 1
|
||||
@@ -113,7 +115,9 @@ def test_two_text_files_with_zip_metadata(
|
||||
return_value=mock_file_store,
|
||||
):
|
||||
connector = LocalFileConnector(
|
||||
file_locations=["file1.txt", "file2.txt"], zip_metadata=zip_metadata
|
||||
file_locations=["file1.txt", "file2.txt"],
|
||||
file_names=["file1.txt", "file2.txt"],
|
||||
zip_metadata=zip_metadata,
|
||||
)
|
||||
batches = list(connector.load_from_state())
|
||||
|
||||
|
||||
@@ -10,6 +10,8 @@ from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
|
||||
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.utils import DocumentRow
|
||||
from onyx.db.utils import SortOrder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ACCESS_MAPPING
|
||||
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
|
||||
@@ -71,7 +73,20 @@ def test_gdrive_perm_sync_with_real_data(
|
||||
return_value=_build_connector(google_drive_service_acct_connector_factory),
|
||||
):
|
||||
# Call the function under test
|
||||
doc_access_generator = gdrive_doc_sync(mock_cc_pair, lambda: [], mock_heartbeat)
|
||||
def mock_fetch_all_docs_fn(
|
||||
sort_order: SortOrder | None = None,
|
||||
) -> list[DocumentRow]:
|
||||
return []
|
||||
|
||||
def mock_fetch_all_docs_ids_fn() -> list[str]:
|
||||
return []
|
||||
|
||||
doc_access_generator = gdrive_doc_sync(
|
||||
mock_cc_pair,
|
||||
mock_fetch_all_docs_fn,
|
||||
mock_fetch_all_docs_ids_fn,
|
||||
mock_heartbeat,
|
||||
)
|
||||
doc_access_list = list(doc_access_generator)
|
||||
|
||||
# Verify we got some results
|
||||
|
||||
@@ -11,6 +11,7 @@ import pytest
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
|
||||
|
||||
def extract_key_value_pairs_to_set(
|
||||
@@ -35,7 +36,7 @@ def _load_reference_data(
|
||||
@pytest.fixture
|
||||
def salesforce_connector() -> SalesforceConnector:
|
||||
connector = SalesforceConnector(
|
||||
requested_objects=["Account", "Contact", "Opportunity"],
|
||||
requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact", "Opportunity"],
|
||||
)
|
||||
|
||||
username = os.environ["SF_USERNAME"]
|
||||
|
||||
@@ -85,12 +85,12 @@ def sharepoint_credentials() -> dict[str, str]:
|
||||
}
|
||||
|
||||
|
||||
def test_sharepoint_connector_all_sites(
|
||||
def test_sharepoint_connector_all_sites__docs_only(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with no sites
|
||||
connector = SharepointConnector()
|
||||
connector = SharepointConnector(include_site_pages=False)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
@@ -135,12 +135,14 @@ def test_sharepoint_connector_specific_folder(
|
||||
verify_document_content(doc, expected)
|
||||
|
||||
|
||||
def test_sharepoint_connector_root_folder(
|
||||
def test_sharepoint_connector_root_folder__docs_only(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]])
|
||||
connector = SharepointConnector(
|
||||
sites=[os.environ["SHAREPOINT_SITE"]], include_site_pages=False
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
@@ -225,3 +227,59 @@ def test_sharepoint_connector_poll(
|
||||
verify_document_content(
|
||||
doc, [d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"][0]
|
||||
)
|
||||
|
||||
|
||||
def test_sharepoint_connector_pages(
|
||||
mock_get_unstructured_api_key: MagicMock,
|
||||
sharepoint_credentials: dict[str, str],
|
||||
) -> None:
|
||||
# Initialize connector with the base site URL
|
||||
connector = SharepointConnector(
|
||||
sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests-pages"]
|
||||
)
|
||||
|
||||
# Load credentials
|
||||
connector.load_credentials(sharepoint_credentials)
|
||||
|
||||
# Get documents within the time window
|
||||
document_batches = list(connector.load_from_state())
|
||||
found_documents: list[Document] = [
|
||||
doc for batch in document_batches for doc in batch
|
||||
]
|
||||
|
||||
# Should only find CollabHome
|
||||
assert len(found_documents) == 1, "Should only find one page"
|
||||
doc = found_documents[0]
|
||||
assert doc.semantic_identifier == "CollabHome"
|
||||
verify_document_metadata(doc)
|
||||
assert len(doc.sections) == 1
|
||||
assert (
|
||||
doc.sections[0].text
|
||||
== """
|
||||
# Home
|
||||
|
||||
Display recent news.
|
||||
|
||||
## News
|
||||
|
||||
Show recent activities from your site
|
||||
|
||||
## Site activity
|
||||
|
||||
## Quick links
|
||||
|
||||
Learn about a team site
|
||||
|
||||
Learn how to add a page
|
||||
|
||||
Add links to important documents and pages.
|
||||
|
||||
## Quick links
|
||||
|
||||
Documents
|
||||
|
||||
Add a document library
|
||||
|
||||
## Document library
|
||||
""".strip()
|
||||
)
|
||||
|
||||
@@ -33,7 +33,7 @@ def _load_all_docs(
|
||||
for document, failure, next_checkpoint in doc_batch_generator:
|
||||
if failure is not None:
|
||||
raise RuntimeError(f"Failed to load documents: {failure}")
|
||||
if document is not None:
|
||||
if document is not None and isinstance(document, Document):
|
||||
documents.append(document)
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
@@ -100,7 +100,7 @@ def load_everything_from_checkpoint_connector(
|
||||
for document, failure, next_checkpoint in doc_batch_generator:
|
||||
if failure is not None:
|
||||
outputs.append(failure)
|
||||
if document is not None:
|
||||
if document is not None and isinstance(document, Document):
|
||||
outputs.append(document)
|
||||
if next_checkpoint is not None:
|
||||
checkpoint = next_checkpoint
|
||||
|
||||
@@ -34,7 +34,7 @@ class ConnectorManager:
|
||||
connector_specific_config=(
|
||||
connector_specific_config
|
||||
or (
|
||||
{"file_locations": [], "zip_metadata": {}}
|
||||
{"file_locations": [], "file_names": [], "zip_metadata": {}}
|
||||
if source == DocumentSource.FILE
|
||||
else {}
|
||||
)
|
||||
|
||||
Binary file not shown.
@@ -21,6 +21,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
FILE_NAME = "Sample.pdf"
|
||||
FILE_PATH = "tests/integration/common_utils/test_files"
|
||||
DOCX_FILE_NAME = "three_images.docx"
|
||||
|
||||
|
||||
def test_image_indexing(
|
||||
@@ -67,7 +68,11 @@ def test_image_indexing(
|
||||
name=connector_name,
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={"file_locations": file_paths, "zip_metadata": {}},
|
||||
connector_specific_config={
|
||||
"file_locations": file_paths,
|
||||
"file_names": [FILE_NAME],
|
||||
"zip_metadata": {},
|
||||
},
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[],
|
||||
user_performing_action=admin_user,
|
||||
@@ -110,3 +115,112 @@ def test_image_indexing(
|
||||
else:
|
||||
assert document.image_file_id is not None
|
||||
assert file_paths[0] in document.image_file_id
|
||||
|
||||
|
||||
def test_docx_image_indexing(
|
||||
reset: None,
|
||||
admin_user: DATestUser,
|
||||
vespa_client: vespa_fixture,
|
||||
) -> None:
|
||||
"""Test that images from docx files are correctly extracted and indexed."""
|
||||
os.makedirs(FILE_PATH, exist_ok=True)
|
||||
test_file_path = os.path.join(FILE_PATH, DOCX_FILE_NAME)
|
||||
|
||||
# Use FileManager to upload the test file
|
||||
upload_response = FileManager.upload_file_for_connector(
|
||||
file_path=test_file_path,
|
||||
file_name=DOCX_FILE_NAME,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
LLMProviderManager.create(
|
||||
name="test_llm_docx",
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
SettingsManager.update_settings(
|
||||
DATestSettings(
|
||||
search_time_image_analysis_enabled=True,
|
||||
image_extraction_and_analysis_enabled=True,
|
||||
),
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
file_paths = upload_response.file_paths
|
||||
|
||||
if not file_paths:
|
||||
pytest.fail("File upload failed - no file paths returned")
|
||||
|
||||
# Create a dummy credential for the file connector
|
||||
credential = CredentialManager.create(
|
||||
source=DocumentSource.FILE,
|
||||
credential_json={},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Create the connector
|
||||
connector_name = f"DocxFileConnector-{int(datetime.now().timestamp())}"
|
||||
connector = ConnectorManager.create(
|
||||
name=connector_name,
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": file_paths,
|
||||
"file_names": [DOCX_FILE_NAME],
|
||||
"zip_metadata": {},
|
||||
},
|
||||
access_type=AccessType.PUBLIC,
|
||||
groups=[],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Link the credential to the connector
|
||||
cc_pair = CCPairManager.create(
|
||||
credential_id=credential.id,
|
||||
connector_id=connector.id,
|
||||
access_type=AccessType.PUBLIC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Explicitly run the connector to start indexing
|
||||
CCPairManager.run_once(
|
||||
cc_pair=cc_pair,
|
||||
from_beginning=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
CCPairManager.wait_for_indexing_completion(
|
||||
cc_pair=cc_pair,
|
||||
after=datetime.now(timezone.utc),
|
||||
timeout=300,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Fetch documents from Vespa - expect text content plus 3 images
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
vespa_client=vespa_client,
|
||||
)
|
||||
|
||||
# Should have documents for text content plus 3 images
|
||||
assert (
|
||||
len(documents) >= 3
|
||||
), f"Expected at least 3 documents (3 images), got {len(documents)}"
|
||||
|
||||
# Count documents with images
|
||||
image_documents = [doc for doc in documents if doc.image_file_id is not None]
|
||||
text_documents = [doc for doc in documents if doc.image_file_id is None]
|
||||
|
||||
assert (
|
||||
len(image_documents) == 3
|
||||
), f"Expected exactly 3 image documents, got {len(image_documents)}"
|
||||
assert (
|
||||
len(text_documents) >= 1
|
||||
), f"Expected at least 1 text document, got {len(text_documents)}"
|
||||
|
||||
# Verify each image document has a valid image_file_id pointing to our uploaded file
|
||||
for image_doc in image_documents:
|
||||
assert file_paths[0] in (
|
||||
image_doc.image_file_id or ""
|
||||
), f"Image document should reference uploaded file: {image_doc.image_file_id}"
|
||||
|
||||
@@ -76,6 +76,7 @@ def test_zip_metadata_handling(
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={
|
||||
"file_locations": file_paths,
|
||||
"file_names": [os.path.basename(file_path) for file_path in file_paths],
|
||||
"zip_metadata": metadata,
|
||||
},
|
||||
access_type=AccessType.PUBLIC,
|
||||
|
||||
@@ -65,7 +65,11 @@ def kg_test_docs() -> tuple[list[str], int, list[KGEntityType]]:
|
||||
name="KG-Test-FileConnector",
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={"file_locations": [], "zip_metadata": {}},
|
||||
connector_specific_config={
|
||||
"file_locations": [],
|
||||
"file_names": [],
|
||||
"zip_metadata": {},
|
||||
},
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
api_key = APIKeyManager.create(user_performing_action=admin_user)
|
||||
|
||||
@@ -160,7 +160,11 @@ def create_connector(env_name: str, file_paths: list[str]) -> int:
|
||||
name=connector_name,
|
||||
source=DocumentSource.FILE,
|
||||
input_type=InputType.LOAD_STATE,
|
||||
connector_specific_config={"file_locations": file_paths, "zip_metadata": {}},
|
||||
connector_specific_config={
|
||||
"file_locations": file_paths,
|
||||
"file_names": [], # For regression tests, no need for file_names
|
||||
"zip_metadata": {},
|
||||
},
|
||||
refresh_freq=None,
|
||||
prune_freq=None,
|
||||
indexing_start=None,
|
||||
|
||||
@@ -23,7 +23,7 @@ from onyx.connectors.exceptions import CredentialExpiredError
|
||||
from onyx.connectors.exceptions import InsufficientPermissionsError
|
||||
from onyx.connectors.github.connector import GithubConnector
|
||||
from onyx.connectors.github.connector import GithubConnectorStage
|
||||
from onyx.connectors.github.connector import SerializedRepository
|
||||
from onyx.connectors.github.models import SerializedRepository
|
||||
from onyx.connectors.models import Document
|
||||
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
|
||||
from tests.unit.onyx.connectors.utils import (
|
||||
@@ -97,6 +97,10 @@ def create_mock_pr() -> Callable[..., MagicMock]:
|
||||
else f"https://github.com/test-org/test-repo/pull/{number}"
|
||||
)
|
||||
mock_pr.raw_data = {}
|
||||
mock_pr.base = MagicMock()
|
||||
mock_pr.base.repo = MagicMock()
|
||||
mock_pr.base.repo.full_name = "test-org/test-repo"
|
||||
|
||||
return mock_pr
|
||||
|
||||
return _create_mock_pr
|
||||
@@ -121,6 +125,11 @@ def create_mock_issue() -> Callable[..., MagicMock]:
|
||||
mock_issue.html_url = f"https://github.com/test-org/test-repo/issues/{number}"
|
||||
mock_issue.pull_request = None # Not a PR
|
||||
mock_issue.raw_data = {}
|
||||
|
||||
# Mock the nested repository.full_name attribute
|
||||
mock_issue.repository = MagicMock()
|
||||
mock_issue.repository.full_name = "test-org/test-repo"
|
||||
|
||||
return mock_issue
|
||||
|
||||
return _create_mock_issue
|
||||
@@ -265,7 +274,7 @@ def test_load_from_checkpoint_with_rate_limit(
|
||||
# Call load_from_checkpoint
|
||||
end_time = time.time()
|
||||
with patch(
|
||||
"onyx.connectors.github.connector._sleep_after_rate_limit_exception"
|
||||
"onyx.connectors.github.connector.sleep_after_rate_limit_exception"
|
||||
) as mock_sleep:
|
||||
outputs = load_everything_from_checkpoint_connector(
|
||||
github_connector, 0, end_time
|
||||
@@ -797,7 +806,7 @@ def test_load_from_checkpoint_cursor_pagination_completion(
|
||||
mock_repo1.get_issues.return_value = mock_empty_issues_list
|
||||
mock_repo2.get_issues.return_value = mock_empty_issues_list
|
||||
with patch.object(
|
||||
github_connector, "_get_all_repos", return_value=[mock_repo1, mock_repo2]
|
||||
github_connector, "get_all_repos", return_value=[mock_repo1, mock_repo2]
|
||||
), patch.object(
|
||||
github_connector,
|
||||
"_pull_requests_func",
|
||||
|
||||
@@ -34,10 +34,16 @@ def mock_fetch_all_existing_docs_fn() -> MagicMock:
|
||||
return MagicMock(return_value=[])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_fetch_all_existing_docs_ids_fn() -> MagicMock:
|
||||
return MagicMock(return_value=[])
|
||||
|
||||
|
||||
def test_jira_permission_sync(
|
||||
jira_connector: JiraConnector,
|
||||
mock_jira_cc_pair: MagicMock,
|
||||
mock_fetch_all_existing_docs_fn: MagicMock,
|
||||
mock_fetch_all_existing_docs_ids_fn: MagicMock,
|
||||
) -> None:
|
||||
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
|
||||
mock_build_client.return_value = jira_connector._jira_client
|
||||
@@ -45,5 +51,6 @@ def test_jira_permission_sync(
|
||||
for doc in jira_doc_sync(
|
||||
cc_pair=mock_jira_cc_pair,
|
||||
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
|
||||
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
|
||||
):
|
||||
print(doc)
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Test script for the new custom query configuration functionality in SalesforceConnector.
|
||||
|
||||
This demonstrates how to use the new custom_query_config parameter to specify
|
||||
exactly which fields and associations (child objects) to retrieve for each object type.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from onyx.connectors.salesforce.connector import _validate_custom_query_config
|
||||
from onyx.connectors.salesforce.connector import SalesforceConnector
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
|
||||
|
||||
def test_custom_query_config() -> None:
|
||||
"""Test the custom query configuration functionality."""
|
||||
|
||||
# Example custom query configuration
|
||||
# This specifies exactly which fields and associations to retrieve
|
||||
custom_config = {
|
||||
ACCOUNT_OBJECT_TYPE: {
|
||||
"fields": ["Id", "Name", "Industry", "CreatedDate", MODIFIED_FIELD],
|
||||
"associations": {
|
||||
"Contact": ["Id", "FirstName", "LastName", "Email"],
|
||||
"Opportunity": ["Id", "Name", "StageName", "Amount", "CloseDate"],
|
||||
},
|
||||
},
|
||||
"Lead": {
|
||||
"fields": ["Id", "FirstName", "LastName", "Company", "Status"],
|
||||
"associations": {}, # No associations for Lead
|
||||
},
|
||||
}
|
||||
|
||||
# Create connector with custom configuration
|
||||
connector = SalesforceConnector(
|
||||
batch_size=50, custom_query_config=json.dumps(custom_config)
|
||||
)
|
||||
|
||||
print("✅ SalesforceConnector created successfully with custom query config")
|
||||
print(f"Parent object list: {connector.parent_object_list}")
|
||||
print(f"Custom config keys: {list(custom_config.keys())}")
|
||||
|
||||
# Test that the parent object list is derived from the custom config
|
||||
assert connector.parent_object_list == [ACCOUNT_OBJECT_TYPE, "Lead"]
|
||||
assert connector.custom_query_config == custom_config
|
||||
|
||||
print("✅ Basic validation passed")
|
||||
|
||||
|
||||
def test_traditional_config() -> None:
|
||||
"""Test that the traditional requested_objects approach still works."""
|
||||
|
||||
# Traditional approach
|
||||
connector = SalesforceConnector(
|
||||
batch_size=50, requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact"]
|
||||
)
|
||||
|
||||
print("✅ SalesforceConnector created successfully with traditional config")
|
||||
print(f"Parent object list: {connector.parent_object_list}")
|
||||
|
||||
# Test that it still works the old way
|
||||
assert connector.parent_object_list == [ACCOUNT_OBJECT_TYPE, "Contact"]
|
||||
assert connector.custom_query_config is None
|
||||
|
||||
print("✅ Traditional config validation passed")
|
||||
|
||||
|
||||
def test_validation() -> None:
|
||||
"""Test that invalid configurations are rejected."""
|
||||
|
||||
# Test invalid config structure
|
||||
invalid_configs: list[Any] = [
|
||||
# Invalid fields type
|
||||
{ACCOUNT_OBJECT_TYPE: {"fields": "invalid"}},
|
||||
# Invalid associations type
|
||||
{ACCOUNT_OBJECT_TYPE: {"associations": "invalid"}},
|
||||
# Nested invalid structure
|
||||
{ACCOUNT_OBJECT_TYPE: {"associations": {"Contact": {"fields": "invalid"}}}},
|
||||
]
|
||||
|
||||
for i, invalid_config in enumerate(invalid_configs):
|
||||
try:
|
||||
_validate_custom_query_config(invalid_config)
|
||||
assert False, f"Should have raised ValueError for invalid_config[{i}]"
|
||||
except ValueError:
|
||||
print(f"✅ Correctly rejected invalid config {i}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Testing SalesforceConnector custom query configuration...")
|
||||
print("=" * 60)
|
||||
|
||||
test_custom_query_config()
|
||||
print()
|
||||
|
||||
test_traditional_config()
|
||||
print()
|
||||
|
||||
test_validation()
|
||||
print()
|
||||
|
||||
print("=" * 60)
|
||||
print("🎉 All tests passed! The custom query configuration is working correctly.")
|
||||
print()
|
||||
print("Example usage:")
|
||||
print(
|
||||
"""
|
||||
# Custom configuration approach
|
||||
custom_config = {
|
||||
ACCOUNT_OBJECT_TYPE: {
|
||||
"fields": ["Id", "Name", "Industry"],
|
||||
"associations": {
|
||||
"Contact": {
|
||||
"fields": ["Id", "FirstName", "LastName", "Email"],
|
||||
"associations": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
connector = SalesforceConnector(custom_query_config=custom_config)
|
||||
|
||||
# Traditional approach (still works)
|
||||
connector = SalesforceConnector(requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact"])
|
||||
"""
|
||||
)
|
||||
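The new test module exercises _validate_custom_query_config, but the validator itself is not shown in this diff. A hedged sketch of a validator that accepts the valid configs and rejects the invalid ones used in the test (the real implementation in the connector may differ):

from typing import Any


def validate_custom_query_config(config: dict[str, Any]) -> None:
    for object_type, spec in config.items():
        fields = spec.get("fields", [])
        if not isinstance(fields, list):
            raise ValueError(f"{object_type}: 'fields' must be a list of field names")
        associations = spec.get("associations", {})
        if not isinstance(associations, dict):
            raise ValueError(f"{object_type}: 'associations' must be a mapping")
        for child, child_spec in associations.items():
            if isinstance(child_spec, list):
                continue  # a plain list of child field names is allowed
            if isinstance(child_spec, dict):
                # nested spec: validate it the same way
                validate_custom_query_config({child: child_spec})
                continue
            raise ValueError(f"{object_type}.{child}: expected a field list or nested spec")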
@@ -26,6 +26,9 @@ from onyx.connectors.salesforce.salesforce_calls import _make_time_filter_for_sf
|
||||
from onyx.connectors.salesforce.salesforce_calls import _make_time_filtered_query
|
||||
from onyx.connectors.salesforce.salesforce_calls import get_object_by_id_query
|
||||
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
|
||||
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
|
||||
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
|
||||
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# from onyx.connectors.salesforce.onyx_salesforce_type import OnyxSalesforceType
|
||||
@@ -153,7 +156,7 @@ def _create_csv_file_and_update_db(
|
||||
Creates a CSV file for the given object type and records.
|
||||
|
||||
Args:
|
||||
object_type: The Salesforce object type (e.g. "Account", "Contact")
|
||||
object_type: The Salesforce object type (e.g. ACCOUNT_OBJECT_TYPE, "Contact")
|
||||
records: List of dictionaries containing the record data
|
||||
filename: Name of the CSV file to create (default: test_data.csv)
|
||||
"""
|
||||
@@ -184,7 +187,7 @@ def _create_csv_with_example_data(sf_db: OnyxSalesforceSQLite) -> None:
|
||||
Creates CSV files with example data, organized by object type.
|
||||
"""
|
||||
example_data: dict[str, list[dict]] = {
|
||||
"Account": [
|
||||
ACCOUNT_OBJECT_TYPE: [
|
||||
{
|
||||
"Id": _VALID_SALESFORCE_IDS[0],
|
||||
"Name": "Acme Inc.",
|
||||
@@ -428,7 +431,7 @@ def _test_query(sf_db: OnyxSalesforceSQLite) -> None:
|
||||
}
|
||||
|
||||
# Get all Account IDs
|
||||
account_ids = sf_db.find_ids_by_type("Account")
|
||||
account_ids = sf_db.find_ids_by_type(ACCOUNT_OBJECT_TYPE)
|
||||
|
||||
# Verify we found all expected accounts
|
||||
assert len(account_ids) == len(
|
||||
@@ -480,7 +483,9 @@ def _test_upsert(sf_db: OnyxSalesforceSQLite) -> None:
|
||||
},
|
||||
]
|
||||
|
||||
_create_csv_file_and_update_db(sf_db, "Account", update_data, "update_data.csv")
|
||||
_create_csv_file_and_update_db(
|
||||
sf_db, ACCOUNT_OBJECT_TYPE, update_data, "update_data.csv"
|
||||
)
|
||||
|
||||
# Verify the update worked
|
||||
updated_record = sf_db.get_record(_VALID_SALESFORCE_IDS[0])
|
||||
@@ -573,7 +578,7 @@ def _test_account_with_children(sf_db: OnyxSalesforceSQLite) -> None:
|
||||
3. Child object data is complete and accurate
|
||||
"""
|
||||
# First get all account IDs
|
||||
account_ids = sf_db.find_ids_by_type("Account")
|
||||
account_ids = sf_db.find_ids_by_type(ACCOUNT_OBJECT_TYPE)
|
||||
assert len(account_ids) > 0, "No accounts found"
|
||||
|
||||
# For each account, get its children and verify the data
|
||||
@@ -690,7 +695,7 @@ def _test_get_affected_parent_ids(sf_db: OnyxSalesforceSQLite) -> None:
"""
# Create test data with relationships
test_data = {
"Account": [
ACCOUNT_OBJECT_TYPE: [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Parent Account 1",
@@ -720,40 +725,46 @@ def _test_get_affected_parent_ids(sf_db: OnyxSalesforceSQLite) -> None:

# Test Case 1: Account directly in updated_ids and parent_types
updated_ids = [_VALID_SALESFORCE_IDS[1]] # Parent Account 2
parent_types = set(["Account"])
parent_types = set([ACCOUNT_OBJECT_TYPE])
affected_ids_by_type = defaultdict(set)
for parent_type, parent_id, _ in sf_db.get_changed_parent_ids_by_type(
updated_ids, parent_types
):
affected_ids_by_type[parent_type].add(parent_id)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[1] in affected_ids_by_type["Account"]
ACCOUNT_OBJECT_TYPE in affected_ids_by_type
), "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[1] in affected_ids_by_type[ACCOUNT_OBJECT_TYPE]
), "Direct parent ID not included"

# Test Case 2: Account with child in updated_ids
updated_ids = [_VALID_SALESFORCE_IDS[40]] # Child Contact
parent_types = set(["Account"])
parent_types = set([ACCOUNT_OBJECT_TYPE])
affected_ids_by_type = defaultdict(set)
for parent_type, parent_id, _ in sf_db.get_changed_parent_ids_by_type(
updated_ids, parent_types
):
affected_ids_by_type[parent_type].add(parent_id)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids_by_type["Account"]
ACCOUNT_OBJECT_TYPE in affected_ids_by_type
), "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids_by_type[ACCOUNT_OBJECT_TYPE]
), "Parent of updated child not included"

# Test Case 3: Both direct and indirect affects
updated_ids = [_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]] # Both cases
parent_types = set(["Account"])
parent_types = set([ACCOUNT_OBJECT_TYPE])
affected_ids_by_type = defaultdict(set)
for parent_type, parent_id, _ in sf_db.get_changed_parent_ids_by_type(
updated_ids, parent_types
):
affected_ids_by_type[parent_type].add(parent_id)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
affected_ids = affected_ids_by_type["Account"]
assert (
ACCOUNT_OBJECT_TYPE in affected_ids_by_type
), "Account type not in affected_ids_by_type"
affected_ids = affected_ids_by_type[ACCOUNT_OBJECT_TYPE]
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
@@ -929,7 +940,7 @@ def _get_child_records_by_id_query(
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> str:
"""Returns a SOQL query given the object id, type and child relationships.

@@ -963,7 +974,7 @@ def test_salesforce_connector_single() -> None:

# this record has some opportunity child records
parent_id = "001bm00000BXfhEAAT"
parent_type = "Account"
parent_type = ACCOUNT_OBJECT_TYPE
parent_types = [parent_type]

username = os.environ["SF_USERNAME"]
@@ -987,11 +998,11 @@ def test_salesforce_connector_single() -> None:
child_to_parent_types: dict[str, set[str]] = (
{}
) # reverse map from child to parent types
child_relationship_to_queryable_fields: dict[str, list[str]] = {}
child_relationship_to_queryable_fields: dict[str, set[str]] = {}

# parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {}

# Step 1 - make a list of all the types to download (parent + direct child + "User")
# Step 1 - make a list of all the types to download (parent + direct child + USER_OBJECT_TYPE)
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
for parent_type_working in parent_types:
child_types_working = sf_client.get_children_of_sf_type(parent_type_working)
@@ -1035,8 +1046,8 @@ def test_salesforce_connector_single() -> None:
result = sf_client.query(query)
records = result["records"]
record = records[0]
assert record["attributes"]["type"] == "Account"
parent_last_modified_date = record.get("LastModifiedDate", "")
assert record["attributes"]["type"] == ACCOUNT_OBJECT_TYPE
parent_last_modified_date = record.get(MODIFIED_FIELD, "")
parent_semantic_identifier = record.get("Name", "Unknown Object")
parent_last_modified_by_id = record.get("LastModifiedById")

@@ -1163,9 +1174,9 @@ def test_salesforce_connector_single() -> None:
# get user relationship if present
primary_owner_list = None
if parent_last_modified_by_id:
queryable_user_fields = sf_client.get_queryable_fields_by_type("User")
queryable_user_fields = sf_client.get_queryable_fields_by_type(USER_OBJECT_TYPE)
query = get_object_by_id_query(
parent_last_modified_by_id, "User", queryable_user_fields
parent_last_modified_by_id, USER_OBJECT_TYPE, queryable_user_fields
)
result = sf_client.query(query)
user_record = result["records"][0]
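
The hunk above uses `get_object_by_id_query` to fetch the modifying User record. Its body is not part of this diff; as a rough sketch of the shape such a helper presumably has (the field formatting and quoting are assumptions), it builds a single-record SOQL lookup:

def get_object_by_id_query(object_id: str, sf_type: str, queryable_fields: set[str]) -> str:
    """Sketch only: SOQL lookup of one record by Id over the queryable fields."""
    fields = ", ".join(sorted(queryable_fields))
    return f"SELECT {fields} FROM {sf_type} WHERE Id = '{object_id}'"

# e.g. get_object_by_id_query(parent_last_modified_by_id, USER_OBJECT_TYPE, {"Id", "Name", "Email"})
# -> "SELECT Email, Id, Name FROM User WHERE Id = '...'"
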

@@ -4,7 +4,7 @@ dependencies:
version: 14.3.1
- name: vespa
repository: https://onyx-dot-app.github.io/vespa-helm-charts
version: 0.2.23
version: 0.2.24
- name: nginx
repository: oci://registry-1.docker.io/bitnamicharts
version: 15.14.0
@@ -14,5 +14,5 @@ dependencies:
- name: minio
repository: oci://registry-1.docker.io/bitnamicharts
version: 17.0.4
digest: sha256:4c938cf9138e4ff6f5ecac5c044324d508ef2b0e1a23ba3f2bc089015cb40ff6
generated: "2025-06-16T18:53:19.63168-07:00"
digest: sha256:dddd687525764f5698adc339a11d268b0ee9c3ca81f8d46c9e65a6bf2c21cf25
generated: "2025-08-06T19:00:41.218513-07:00"

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.2.2
version: 0.2.5
appVersion: latest
annotations:
category: Productivity
@@ -23,7 +23,7 @@ dependencies:
repository: https://charts.bitnami.com/bitnami
condition: postgresql.enabled
- name: vespa
version: 0.2.23
version: 0.2.24
repository: https://onyx-dot-app.github.io/vespa-helm-charts
condition: vespa.enabled
- name: nginx

@@ -123,15 +123,15 @@ function ActionForm({
<button
type="button"
className="
absolute
bottom-4
absolute
bottom-4
right-4
border-border
border
bg-background
rounded
py-1
px-3
py-1
px-3
text-sm
hover:bg-accent-background
"
@@ -162,7 +162,7 @@ function ActionForm({
/>
<div className="mt-4 text-sm bg-blue-50 text-blue-700 dark:text-blue-300 dark:bg-blue-900 p-4 rounded-md border border-blue-200 dark:border-blue-800">
<Link
href="https://docs.onyx.app/tools/custom"
href="https://docs.onyx.app/actions/custom#custom-actions"
className="text-link hover:underline flex items-center"
target="_blank"
rel="noopener noreferrer"

Some files were not shown because too many files have changed in this diff.