Compare commits


38 Commits

Author SHA1 Message Date
Evan Lohn
555630070b more sf logs 2025-08-07 19:18:37 -07:00
Evan Lohn
1d16c96009 fix: sf connector docs 2025-08-07 19:18:37 -07:00
Evan Lohn
297720c132 refactor: file processing (#5136)
* file processing refactor

* mypy

* CW comments

* address CW
2025-08-08 00:34:35 +00:00
Evan Lohn
bd4bd00cef feat: office parsing markitdown (#5115)
* switch to markitdown untested

* passing tests

* reset file

* dotenv version

* docs

* add test file

* add doc

* fix integration test
2025-08-07 23:26:02 +00:00
Chris Weaver
07c482f727 Make starter messages visible on smaller screens (#5170) 2025-08-07 16:49:18 -07:00
Wenxi
cf193dee29 feat: support gpt5 models (#5169)
* support gpt5 models

* gpt5mini visible
2025-08-07 12:35:46 -07:00
Evan Lohn
1b47fa2700 fix: remove erroneous error case and add valid error (#5163)
* fix: remove erroneous error case and add valid error

* also address docfetching-docprocessing limbo
2025-08-07 18:17:00 +00:00
Wenxi Onyx
e1a305d18a mask llm api key from logs 2025-08-07 00:01:29 -07:00
Evan Lohn
e2233d22c9 feat: salesforce custom query (#5158)
* WIP merged approach untested

* tested custom configs

* JT comments

* fix unit test

* CW comments

* fix unit test
2025-08-07 02:37:23 +00:00
Justin Tahara
20d1175312 feat(infra): Bump Vespa Helm Version (#5161)
* feat(infra): Bump Vespa Helm Version

* Adding the Chart.lock file
2025-08-06 19:06:18 -07:00
justin-tahara
7117774287 Revert that change. Let's do this properly 2025-08-06 18:54:21 -07:00
justin-tahara
77f2660bb2 feat(infra): Update Vespa Helm Chart Version 2025-08-06 18:53:02 -07:00
Wenxi
1b2f4f3b87 fix: slash command slackbot to respond in private msg (#5151)
* fix slash command slackbot to respond in private msg

* rename confusing variable. fix slash message response in DMs
2025-08-05 19:03:38 -07:00
Evan Lohn
d85b55a9d2 no more scheduled stalling (#5154) 2025-08-05 20:17:44 +00:00
Justin Tahara
e2bae5a2d9 fix(infra): Adding helm directory (#5156)
* feat(infra): Adding helm directory

* one more fix
2025-08-05 14:11:57 -07:00
Justin Tahara
cc9c76c4fb feat(infra): Release Charts on Github Pages (#5155) 2025-08-05 14:03:28 -07:00
Chris Weaver
258e08abcd feat: add customization via env vars for curator role (#5150)
* Add customization via env vars for curator role

* Simplify

* Simplify more

* Address comments
2025-08-05 09:58:36 -07:00
Evan Lohn
67047e42a7 fix: preserve error traces (#5152) 2025-08-05 09:44:55 -07:00
SubashMohan
146628e734 fix unsupported character error in minio migration (#5145)
* fix unsupported character error in minio migration

* slash fix
2025-08-04 12:42:07 -07:00
Wenxi
c1d4b08132 fix: minio file names (#5138)
* nit var clarity

* maintain file names in connector config for display

* remove unused util

* migration draft

* optional file names to not break existing instances

* backwards compatible

* backwards compatible

* migration logging

* update file conn tests

* unnecessary none

* mypy + explanatory comments
2025-08-01 20:31:29 +00:00
Justin Tahara
f3f47d0709 feat(infra): Creating new helm chart action workflow (#5137)
* feat(infra) Creating new helm chart action workflow

* Adding the steps

* Adding in dependencies

* One more debug

* Adding a new step to install helm
2025-08-01 09:26:58 -07:00
Justin Tahara
fe26a1bfcc feat(infra): Codeowner for Helm directory (#5139) 2025-07-31 23:05:46 +00:00
Wenxi
554cd0f891 fix: accept multiple zip types and fallback to extension (#5135)
* accept multiple zip types and fallback to extension

* move zip check to util

* mypy nit
2025-07-30 22:21:16 +00:00
Raunak Bhagat
f87d3e9849 fix: Make ungrounded types have a default name when sending to the frontend (#5133)
* Update names in map-comprehension

* Make default name for ungrounded types public

* Return the default name for ungrounded entity-types

* Update backend/onyx/db/entities.py

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>

---------

Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
2025-07-30 20:46:30 +00:00
Rei Meguro
72cdada893 edit link to custom actions (#5129) 2025-07-30 15:08:39 +00:00
SubashMohan
c442ebaff6 Feature/GitHub permission sync (#4996)
* github perm sync initial draft

* introduce github doc sync and perm sync

* remove specific start time check

* Refactor GitHub connector to use SlimCheckpointOutputWrapper for improved document handling

* Update GitHub sync frequency defaults from 30 minutes to 5 minutes

* Add stop signal handling and progress reporting in GitHub document sync

* Refactor tests for Confluence and Google Drive connectors to use a mock fetch function for document access

* change the doc_sync approach

* add static typing for document columns and where clause

* remove prefix logic in connector runner

* mypy fix

* code review changes

* mypy fix

* fix review comments

* add sort order

* Implement merge heads migration for Alembic and update Confluence and Google Drive test

* github unit tests fix

* delete merge head and rebase the docmetadata field migration

---------

Co-authored-by: Subash <subash@onyx.app>
2025-07-30 02:42:18 +00:00
Justin Tahara
56f16d107e feat(infra): Update helm version after new feature (#5120) 2025-07-29 16:31:35 -07:00
Justin Tahara
0157ae099a [Vespa] Update to optimized configuration pt.2 (#5113) 2025-07-28 20:42:31 +00:00
justin-tahara
565fb42457 Let's do this properly 2025-07-28 10:42:31 -07:00
justin-tahara
a50a8b4a12 [Vespa] Update to optimized configuration 2025-07-28 10:38:48 -07:00
Evan Lohn
4baf4e7d96 feat: pruning freq (#5097)
* pruning frequency increase

* add logs
2025-07-26 22:29:43 +00:00
Wenxi
8b7ab2eb66 onyx metadata minio fix + permissive unstructured fail (#5085) 2025-07-25 21:26:02 +00:00
Evan Lohn
1f75f3633e fix: sidebar ranges (#5084) 2025-07-25 19:46:47 +00:00
Evan Lohn
650884d76a fix: preserve error traces (#5083) 2025-07-25 18:56:11 +00:00
Wenxi
8722bdb414 typo (#5082) 2025-07-25 18:26:21 +00:00
Evan Lohn
71037678c3 attempt to fix parsing of tricky template files (#5080) 2025-07-25 02:18:35 +00:00
Chris Weaver
68de1015e1 feat: support aspx files (#5068)
* Support aspx files

* Add fetching of site pages

* Improve

* Small enhancement

* more improvements

* Improvements

* Fix tests
2025-07-24 19:19:24 -07:00
Evan Lohn
e2b3a6e144 fix: drive external links (#5079) 2025-07-24 17:42:12 -07:00
114 changed files with 3025 additions and 775 deletions

.github/CODEOWNERS

@@ -1 +1,3 @@
* @onyx-dot-app/onyx-core-team
# Helm charts Owners
/helm/ @justin-tahara


@@ -0,0 +1,42 @@
name: Release Onyx Helm Charts
on:
push:
branches:
- main
permissions: write-all
jobs:
release:
permissions:
contents: write
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Configure Git
run: |
git config user.name "$GITHUB_ACTOR"
git config user.email "$GITHUB_ACTOR@users.noreply.github.com"
- name: Install Helm
uses: azure/setup-helm@v4
with:
version: v3.12.1
- name: Add Required Helm Repositories
run: |
helm repo add bitnami https://charts.bitnami.com/bitnami
helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
helm repo update
- name: Run chart-releaser
uses: helm/chart-releaser-action@v1.7.0
with:
charts_dir: deployment/helm/charts
env:
CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}"


@@ -0,0 +1,30 @@
"""add_doc_metadata_field_in_document_model
Revision ID: 3fc5d75723b3
Revises: 2f95e36923e6
Create Date: 2025-07-28 18:45:37.985406
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "3fc5d75723b3"
down_revision = "2f95e36923e6"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"document",
sa.Column(
"doc_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
),
)
def downgrade() -> None:
op.drop_column("document", "doc_metadata")


@@ -0,0 +1,132 @@
"""add file names to file connector config
Revision ID: 62c3a055a141
Revises: 3fc5d75723b3
Create Date: 2025-07-30 17:01:24.417551
"""
from alembic import op
import sqlalchemy as sa
import json
import os
import logging
# revision identifiers, used by Alembic.
revision = "62c3a055a141"
down_revision = "3fc5d75723b3"
branch_labels = None
depends_on = None
SKIP_FILE_NAME_MIGRATION = (
os.environ.get("SKIP_FILE_NAME_MIGRATION", "true").lower() == "true"
)
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
if SKIP_FILE_NAME_MIGRATION:
logger.info(
"Skipping file name migration. Hint: set SKIP_FILE_NAME_MIGRATION=false to run this migration"
)
return
logger.info("Running file name migration")
# Get connection
conn = op.get_bind()
# Get all FILE connectors with their configs
file_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'FILE'
"""
)
).fetchall()
for connector_id, config in file_connectors:
# Parse config if it's a string
if isinstance(config, str):
config = json.loads(config)
# Get file_locations list
file_locations = config.get("file_locations", [])
# Get display names for each file_id
file_names = []
for file_id in file_locations:
result = conn.execute(
sa.text(
"""
SELECT display_name
FROM file_record
WHERE file_id = :file_id
"""
),
{"file_id": file_id},
).fetchone()
if result:
file_names.append(result[0])
else:
file_names.append(file_id) # Should not happen
# Add file_names to config
new_config = dict(config)
new_config["file_names"] = file_names
# Update the connector
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
def downgrade() -> None:
# Get connection
conn = op.get_bind()
# Remove file_names from all FILE connectors
file_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'FILE'
"""
)
).fetchall()
for connector_id, config in file_connectors:
# Parse config if it's a string
if isinstance(config, str):
config = json.loads(config)
# Remove file_names if it exists
if "file_names" in config:
new_config = dict(config)
del new_config["file_names"]
# Update the connector
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{
"connector_id": connector_id,
"new_config": json.dumps(new_config),
},
)


@@ -47,6 +47,7 @@ from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import get_document_ids_for_connector_credential_pair
from onyx.db.document import get_documents_for_connector_credential_pair_limited_columns
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
@@ -58,7 +59,9 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.db.utils import DocumentRow
from onyx.db.utils import is_retryable_sqlalchemy_error
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
@@ -498,16 +501,31 @@ def connector_permission_sync_generator_task(
# this can be used to determine documents that are "missing" and thus
# should no longer be accessible. The decision as to whether we should find
# every document during the doc sync process is connector-specific.
def fetch_all_existing_docs_fn() -> list[str]:
return get_document_ids_for_connector_credential_pair(
def fetch_all_existing_docs_fn(
sort_order: SortOrder | None = None,
) -> list[DocumentRow]:
result = get_documents_for_connector_credential_pair_limited_columns(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
sort_order=sort_order,
)
return list(result)
def fetch_all_existing_docs_ids_fn() -> list[str]:
result = get_document_ids_for_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector.id,
credential_id=cc_pair.credential.id,
)
return result
doc_sync_func = sync_config.doc_sync_config.doc_sync_func
document_external_accesses = doc_sync_func(
cc_pair, fetch_all_existing_docs_fn, callback
cc_pair,
fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn,
callback,
)
task_logger.info(


@@ -71,6 +71,19 @@ GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
)
#####
# GitHub
#####
# In seconds, default is 5 minutes
GITHUB_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("GITHUB_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
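The `int(os.environ.get(NAME) or default)` idiom used for these frequencies behaves differently from `os.environ.get(NAME, default)`: with `or`, a variable that is set but empty also falls back to the default rather than crashing on `int("")`. A quick illustration (the helper name is ours):

```python
import os


def read_frequency(name: str, default_seconds: int) -> int:
    # `or` treats both a missing variable and an empty-string value as
    # unset, so NAME="" falls back instead of raising ValueError in int().
    return int(os.environ.get(name) or default_seconds)


os.environ["GITHUB_PERMISSION_DOC_SYNC_FREQUENCY"] = ""
freq = read_frequency("GITHUB_PERMISSION_DOC_SYNC_FREQUENCY", 5 * 60)  # falls back to 300
```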
#####
# Slack
#####


@@ -18,9 +18,9 @@
<!-- <document type="danswer_chunk" mode="index" /> -->
{{ document_elements }}
</documents>
<nodes count="75">
<resources vcpu="8.0" memory="64.0Gb" architecture="arm64" storage-type="local"
disk="474.0Gb" />
<nodes count="60">
<resources vcpu="8.0" memory="128.0Gb" architecture="arm64" storage-type="local"
disk="475.0Gb" />
</nodes>
<engine>
<proton>


@@ -6,6 +6,7 @@ https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.htm
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -25,6 +26,7 @@ CONFLUENCE_DOC_SYNC_LABEL = "confluence_doc_sync"
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""
@@ -43,7 +45,7 @@ def confluence_doc_sync(
yield from generic_doc_sync(
cc_pair=cc_pair,
fetch_all_existing_docs_fn=fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.CONFLUENCE,
slim_connector=confluence_connector,


@@ -0,0 +1,294 @@
import json
from collections.abc import Generator
from github import Github
from github.Repository import Repository
from ee.onyx.external_permissions.github.utils import fetch_repository_team_slugs
from ee.onyx.external_permissions.github.utils import form_collaborators_group_id
from ee.onyx.external_permissions.github.utils import form_organization_group_id
from ee.onyx.external_permissions.github.utils import (
form_outside_collaborators_group_id,
)
from ee.onyx.external_permissions.github.utils import get_external_access_permission
from ee.onyx.external_permissions.github.utils import get_repository_visibility
from ee.onyx.external_permissions.github.utils import GitHubVisibility
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.connector import DocMetadata
from onyx.connectors.github.connector import GithubConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
GITHUB_DOC_SYNC_LABEL = "github_doc_sync"
def github_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
"""
Sync GitHub documents with external access permissions.
This function checks each repository for visibility/team changes and updates
document permissions accordingly without using checkpoints.
"""
logger.info(f"Starting GitHub document sync for CC pair ID: {cc_pair.id}")
# Initialize GitHub connector with credentials
github_connector: GithubConnector = GithubConnector(
**cc_pair.connector.connector_specific_config
)
github_connector.load_credentials(cc_pair.credential.credential_json)
logger.info("GitHub connector credentials loaded successfully")
if not github_connector.github_client:
logger.error("GitHub client initialization failed")
raise ValueError("github_client is required")
# Get all repositories from GitHub API
logger.info("Fetching all repositories from GitHub API")
try:
repos = []
if github_connector.repositories:
if "," in github_connector.repositories:
# Multiple repositories specified
repos = github_connector.get_github_repos(
github_connector.github_client
)
else:
# Single repository
repos = [
github_connector.get_github_repo(github_connector.github_client)
]
else:
# All repositories
repos = github_connector.get_all_repos(github_connector.github_client)
logger.info(f"Found {len(repos)} repositories to check")
except Exception as e:
logger.error(f"Failed to fetch repositories: {e}")
raise
repo_to_doc_list_map: dict[str, list[DocumentRow]] = {}
# sort order is ascending because we want to get the oldest documents first
existing_docs: list[DocumentRow] = fetch_all_existing_docs_fn(
sort_order=SortOrder.ASC
)
logger.info(f"Found {len(existing_docs)} documents to check")
for doc in existing_docs:
try:
doc_metadata = DocMetadata.model_validate_json(json.dumps(doc.doc_metadata))
if doc_metadata.repo not in repo_to_doc_list_map:
repo_to_doc_list_map[doc_metadata.repo] = []
repo_to_doc_list_map[doc_metadata.repo].append(doc)
except Exception as e:
logger.error(f"Failed to parse doc metadata: {e} for doc {doc.id}")
continue
logger.info(f"Found {len(repo_to_doc_list_map)} repositories with existing documents")
# Process each repository individually
for repo in repos:
try:
logger.info(f"Processing repository: {repo.id} (name: {repo.name})")
repo_doc_list: list[DocumentRow] = repo_to_doc_list_map.get(
repo.full_name, []
)
if not repo_doc_list:
logger.warning(
f"No documents found for repository {repo.id} ({repo.name})"
)
continue
current_external_group_ids = repo_doc_list[0].external_user_group_ids or []
# Check if repository has any permission changes
has_changes = _check_repository_for_changes(
repo=repo,
github_client=github_connector.github_client,
current_external_group_ids=current_external_group_ids,
)
if has_changes:
logger.info(
f"Repository {repo.id} ({repo.name}) has changes, updating documents"
)
# Get new external access permissions for this repository
new_external_access = get_external_access_permission(
repo, github_connector.github_client
)
logger.info(
f"Found {len(repo_doc_list)} documents for repository {repo.full_name}"
)
# Yield updated external access for each document
for doc in repo_doc_list:
if callback:
callback.progress(GITHUB_DOC_SYNC_LABEL, 1)
yield DocExternalAccess(
doc_id=doc.id,
external_access=new_external_access,
)
else:
logger.info(
f"Repository {repo.id} ({repo.name}) has no changes, skipping"
)
except Exception as e:
logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")
logger.info(f"GitHub document sync completed for CC pair ID: {cc_pair.id}")
def _check_repository_for_changes(
repo: Repository,
github_client: Github,
current_external_group_ids: list[str],
) -> bool:
"""
Check if repository has any permission changes (visibility or team updates).
"""
logger.info(f"Checking repository {repo.id} ({repo.name}) for changes")
# Check for repository visibility changes using the sample document data
if _is_repo_visibility_changed_from_groups(
repo=repo,
current_external_group_ids=current_external_group_ids,
):
logger.info(f"Repository {repo.id} ({repo.name}) has visibility changes")
return True
# Check for team membership changes if repository is private
if get_repository_visibility(
repo
) == GitHubVisibility.PRIVATE and _teams_updated_from_groups(
repo=repo,
github_client=github_client,
current_external_group_ids=current_external_group_ids,
):
logger.info(f"Repository {repo.id} ({repo.name}) has team changes")
return True
logger.info(f"Repository {repo.id} ({repo.name}) has no changes")
return False
def _is_repo_visibility_changed_from_groups(
repo: Repository,
current_external_group_ids: list[str],
) -> bool:
"""
Check if repository visibility has changed by analyzing existing external group IDs.
Args:
repo: GitHub repository object
current_external_group_ids: List of external group IDs from existing document
Returns:
True if visibility has changed
"""
current_repo_visibility = get_repository_visibility(repo)
logger.info(f"Current repository visibility: {current_repo_visibility.value}")
# Build expected group IDs for current visibility
collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_collaborators_group_id(repo.id),
)
org_group_id = None
if repo.organization:
org_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_organization_group_id(repo.organization.id),
)
# Determine existing visibility from group IDs
has_collaborators_group = collaborators_group_id in current_external_group_ids
has_org_group = org_group_id and org_group_id in current_external_group_ids
if has_collaborators_group:
existing_repo_visibility = GitHubVisibility.PRIVATE
elif has_org_group:
existing_repo_visibility = GitHubVisibility.INTERNAL
else:
existing_repo_visibility = GitHubVisibility.PUBLIC
logger.info(f"Inferred existing visibility: {existing_repo_visibility.value}")
visibility_changed = existing_repo_visibility != current_repo_visibility
if visibility_changed:
logger.info(
f"Visibility changed for repo {repo.id} ({repo.name}): "
f"{existing_repo_visibility.value} -> {current_repo_visibility.value}"
)
return visibility_changed
def _teams_updated_from_groups(
repo: Repository,
github_client: Github,
current_external_group_ids: list[str],
) -> bool:
"""
Check if repository team memberships have changed using existing group IDs.
"""
# Fetch current team slugs for the repository
current_teams = fetch_repository_team_slugs(repo=repo, github_client=github_client)
logger.info(
f"Current teams for repository {repo.id} (name: {repo.name}): {current_teams}"
)
# Build group IDs to exclude from team comparison (non-team groups)
collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_collaborators_group_id(repo.id),
)
outside_collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=form_outside_collaborators_group_id(repo.id),
)
non_team_group_ids = {collaborators_group_id, outside_collaborators_group_id}
# Extract existing team IDs from current external group IDs
existing_team_ids = set()
for group_id in current_external_group_ids:
# Skip all non-team groups, keep only team groups
if group_id not in non_team_group_ids:
existing_team_ids.add(group_id)
# Note: existing_team_ids from DB are already prefixed (e.g., "github__team-slug")
# but current_teams from API are raw team slugs, so we need to add the prefix
current_team_ids = set()
for team_slug in current_teams:
team_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=team_slug,
)
current_team_ids.add(team_group_id)
logger.info(
f"Existing team IDs: {existing_team_ids}, Current team IDs: {current_team_ids}"
)
# Compare actual team IDs to detect changes
teams_changed = current_team_ids != existing_team_ids
if teams_changed:
logger.info(
f"Team changes detected for repo {repo.id} (name: {repo.name}): "
f"existing={existing_team_ids}, current={current_team_ids}"
)
return teams_changed
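Stripped of the API calls, the team-change check above is a set comparison after normalizing both sides to prefixed IDs. A minimal sketch — the `github__` prefix follows the `github__team-slug` example in the code comment; the exact string produced by `build_ext_group_name_for_onyx` is an assumption here:

```python
def teams_changed(existing_group_ids: set[str],
                  current_team_slugs: list[str],
                  non_team_group_ids: set[str],
                  prefix: str = "github__") -> bool:
    # DB-side group IDs are already prefixed; API-side team slugs are
    # raw, so prefix them before comparing the two sets.
    existing_team_ids = existing_group_ids - non_team_group_ids
    current_team_ids = {prefix + slug for slug in current_team_slugs}
    return existing_team_ids != current_team_ids
```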


@@ -0,0 +1,46 @@
from collections.abc import Generator
from github import Repository
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.github.utils import get_external_user_group
from onyx.connectors.github.connector import GithubConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger
logger = setup_logger()
def github_group_sync(
tenant_id: str,
cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
github_connector: GithubConnector = GithubConnector(
**cc_pair.connector.connector_specific_config
)
github_connector.load_credentials(cc_pair.credential.credential_json)
if not github_connector.github_client:
raise ValueError("github_client is required")
logger.info("Starting GitHub group sync...")
repos: list[Repository.Repository] = []
if github_connector.repositories:
if "," in github_connector.repositories:
# Multiple repositories specified
repos = github_connector.get_github_repos(github_connector.github_client)
else:
# Single repository (backward compatibility)
repos = [github_connector.get_github_repo(github_connector.github_client)]
else:
# All repositories
repos = github_connector.get_all_repos(github_connector.github_client)
for repo in repos:
try:
for external_group in get_external_user_group(
repo, github_connector.github_client
):
logger.info(f"External group: {external_group}")
yield external_group
except Exception as e:
logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")


@@ -0,0 +1,488 @@
from collections.abc import Callable
from enum import Enum
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar
from github import Github
from github import RateLimitExceededException
from github.GithubException import GithubException
from github.NamedUser import NamedUser
from github.Organization import Organization
from github.PaginatedList import PaginatedList
from github.Repository import Repository
from github.Team import Team
from pydantic import BaseModel
from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
from onyx.utils.logger import setup_logger
logger = setup_logger()
class GitHubVisibility(Enum):
"""GitHub repository visibility options."""
PUBLIC = "public"
PRIVATE = "private"
INTERNAL = "internal"
MAX_RETRY_COUNT = 3
T = TypeVar("T")
# Higher-order function to wrap GitHub operations with retry and exception handling
def _run_with_retry(
operation: Callable[[], T],
description: str,
github_client: Github,
retry_count: int = 0,
) -> Optional[T]:
"""Execute a GitHub operation with retry on rate limit and exception handling."""
logger.debug(f"Starting operation '{description}', attempt {retry_count + 1}")
try:
result = operation()
logger.debug(f"Operation '{description}' completed successfully")
return result
except RateLimitExceededException:
if retry_count < MAX_RETRY_COUNT:
sleep_after_rate_limit_exception(github_client)
logger.warning(
f"Rate limit exceeded while {description}. Retrying... "
f"(attempt {retry_count + 1}/{MAX_RETRY_COUNT})"
)
return _run_with_retry(
operation, description, github_client, retry_count + 1
)
else:
error_msg = f"Max retries exceeded for {description}"
logger.exception(error_msg)
raise RuntimeError(error_msg)
except GithubException as e:
logger.warning(f"GitHub API error during {description}: {e}")
return None
except Exception as e:
logger.exception(f"Unexpected error during {description}: {e}")
return None
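The shape of `_run_with_retry` can be reduced to a stand-alone sketch: retry only on rate limiting, raise after the retry budget is spent, and swallow everything else into `None`. This is a simplification — the `RateLimited` class stands in for `RateLimitExceededException`, and the real code sleeps via `sleep_after_rate_limit_exception` before retrying:

```python
from collections.abc import Callable
from typing import Optional, TypeVar

T = TypeVar("T")
MAX_RETRY_COUNT = 3


class RateLimited(Exception):
    """Stand-in for github.RateLimitExceededException."""


def run_with_retry(operation: Callable[[], T], retry_count: int = 0) -> Optional[T]:
    try:
        return operation()
    except RateLimited:
        # Rate limits are transient: retry up to the budget, then give up loudly.
        if retry_count < MAX_RETRY_COUNT:
            return run_with_retry(operation, retry_count + 1)
        raise RuntimeError("Max retries exceeded")
    except Exception:
        # Any other API error is treated as "no result" rather than fatal.
        return None
```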
class UserInfo(BaseModel):
"""Represents a GitHub user with their basic information."""
login: str
name: Optional[str] = None
email: Optional[str] = None
class TeamInfo(BaseModel):
"""Represents a GitHub team with its members."""
name: str
slug: str
members: List[UserInfo]
def _fetch_organization_members(
github_client: Github, org_name: str, retry_count: int = 0
) -> List[UserInfo]:
"""Fetch all organization members including owners and regular members."""
org_members: List[UserInfo] = []
logger.info(f"Fetching organization members for {org_name}")
org = _run_with_retry(
lambda: github_client.get_organization(org_name),
f"get organization {org_name}",
github_client,
)
if not org:
logger.error(f"Failed to fetch organization {org_name}")
raise RuntimeError(f"Failed to fetch organization {org_name}")
member_objs: PaginatedList[NamedUser] | list[NamedUser] = (
_run_with_retry(
lambda: org.get_members(filter_="all"),
f"get members for organization {org_name}",
github_client,
)
or []
)
for member in member_objs:
user_info = UserInfo(login=member.login, name=member.name, email=member.email)
org_members.append(user_info)
logger.info(f"Fetched {len(org_members)} members for organization {org_name}")
return org_members
def _fetch_repository_teams_detailed(
repo: Repository, github_client: Github, retry_count: int = 0
) -> List[TeamInfo]:
"""Fetch teams with access to the repository and their members."""
teams_data: List[TeamInfo] = []
logger.info(f"Fetching teams for repository {repo.full_name}")
team_objs: PaginatedList[Team] | list[Team] = (
_run_with_retry(
lambda: repo.get_teams(),
f"get teams for repository {repo.full_name}",
github_client,
)
or []
)
for team in team_objs:
logger.info(
f"Processing team {team.name} (slug: {team.slug}) for repository {repo.full_name}"
)
members: PaginatedList[NamedUser] | list[NamedUser] = (
_run_with_retry(
lambda: team.get_members(),
f"get members for team {team.name}",
github_client,
)
or []
)
team_members = []
for m in members:
user_info = UserInfo(login=m.login, name=m.name, email=m.email)
team_members.append(user_info)
team_info = TeamInfo(name=team.name, slug=team.slug, members=team_members)
teams_data.append(team_info)
logger.info(f"Team {team.name} has {len(team_members)} members")
logger.info(f"Fetched {len(teams_data)} teams for repository {repo.full_name}")
return teams_data
def fetch_repository_team_slugs(
repo: Repository, github_client: Github, retry_count: int = 0
) -> List[str]:
"""Fetch team slugs with access to the repository."""
logger.info(f"Fetching team slugs for repository {repo.full_name}")
teams_data: List[str] = []
team_objs: PaginatedList[Team] | list[Team] = (
_run_with_retry(
lambda: repo.get_teams(),
f"get teams for repository {repo.full_name}",
github_client,
)
or []
)
for team in team_objs:
teams_data.append(team.slug)
logger.info(f"Fetched {len(teams_data)} team slugs for repository {repo.full_name}")
return teams_data
def _get_collaborators_and_outside_collaborators(
github_client: Github,
repo: Repository,
) -> Tuple[List[UserInfo], List[UserInfo]]:
"""Fetch and categorize collaborators into regular and outside collaborators."""
collaborators: List[UserInfo] = []
outside_collaborators: List[UserInfo] = []
logger.info(f"Fetching collaborators for repository {repo.full_name}")
repo_collaborators: PaginatedList[NamedUser] | list[NamedUser] = (
_run_with_retry(
lambda: repo.get_collaborators(),
f"get collaborators for repository {repo.full_name}",
github_client,
)
or []
)
for collaborator in repo_collaborators:
is_outside = False
# Check if collaborator is outside the organization
if repo.organization:
org: Organization | None = _run_with_retry(
lambda: github_client.get_organization(repo.organization.login),
f"get organization {repo.organization.login}",
github_client,
)
if org is not None:
org_obj = org
membership = _run_with_retry(
lambda: org_obj.has_in_members(collaborator),
f"check membership for {collaborator.login} in org {org_obj.login}",
github_client,
)
is_outside = membership is not None and not membership
info = UserInfo(
login=collaborator.login, name=collaborator.name, email=collaborator.email
)
if repo.organization and is_outside:
outside_collaborators.append(info)
else:
collaborators.append(info)
logger.info(
f"Categorized {len(collaborators)} regular and {len(outside_collaborators)} outside collaborators for {repo.full_name}"
)
    return collaborators, outside_collaborators


def form_collaborators_group_id(repository_id: int) -> str:
"""Generate group ID for repository collaborators."""
if not repository_id:
logger.exception("Repository ID is required to generate collaborators group ID")
raise ValueError("Repository ID must be set to generate group ID.")
group_id = f"{repository_id}_collaborators"
    return group_id


def form_organization_group_id(organization_id: int) -> str:
"""Generate group ID for organization using organization ID."""
if not organization_id:
logger.exception(
"Organization ID is required to generate organization group ID"
)
raise ValueError("Organization ID must be set to generate group ID.")
group_id = f"{organization_id}_organization"
    return group_id


def form_outside_collaborators_group_id(repository_id: int) -> str:
"""Generate group ID for outside collaborators."""
if not repository_id:
logger.exception(
"Repository ID is required to generate outside collaborators group ID"
)
raise ValueError("Repository ID must be set to generate group ID.")
group_id = f"{repository_id}_outside_collaborators"
return group_id
def get_repository_visibility(repo: Repository) -> GitHubVisibility:
"""
Get the visibility of a repository.
Returns GitHubVisibility enum member.
"""
if hasattr(repo, "visibility"):
visibility = repo.visibility
logger.info(
f"Repository {repo.full_name} visibility from attribute: {visibility}"
)
try:
return GitHubVisibility(visibility)
except ValueError:
logger.warning(
f"Unknown visibility '{visibility}' for repo {repo.full_name}, defaulting to private"
)
return GitHubVisibility.PRIVATE
    logger.info(
        f"Repository {repo.full_name} has no visibility attribute; defaulting to private"
    )
return GitHubVisibility.PRIVATE
def get_external_access_permission(
repo: Repository, github_client: Github, add_prefix: bool = False
) -> ExternalAccess:
"""
Get the external access permission for a repository.
Uses group-based permissions for efficiency and scalability.
    add_prefix: When this function is called during the initial permission sync via the
        connector, the document record is inserted without the source prefix on the
        group ID, so pass add_prefix=True to let this function apply the prefix
        itself. When invoked from doc_sync, the system already adds the prefix to
        the group ID while processing the ExternalAccess object, so leave it False.
    """
# We maintain collaborators, and outside collaborators as two separate groups
# instead of adding individual user emails to ExternalAccess.external_user_emails for two reasons:
# 1. Changes in repo collaborators (additions/removals) would require updating all documents.
# 2. Repo permissions can change without updating the repo's updated_at timestamp,
# forcing full permission syncs for all documents every time, which is inefficient.
repo_visibility = get_repository_visibility(repo)
logger.info(
f"Generating ExternalAccess for {repo.full_name}: visibility={repo_visibility.value}"
)
if repo_visibility == GitHubVisibility.PUBLIC:
logger.info(
f"Repository {repo.full_name} is public - allowing access to all users"
)
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=set(),
is_public=True,
)
elif repo_visibility == GitHubVisibility.PRIVATE:
logger.info(
f"Repository {repo.full_name} is private - setting up restricted access"
)
collaborators_group_id = form_collaborators_group_id(repo.id)
outside_collaborators_group_id = form_outside_collaborators_group_id(repo.id)
if add_prefix:
collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=collaborators_group_id,
)
outside_collaborators_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=outside_collaborators_group_id,
)
group_ids = {collaborators_group_id, outside_collaborators_group_id}
team_slugs = fetch_repository_team_slugs(repo, github_client)
if add_prefix:
team_slugs = [
build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=slug,
)
for slug in team_slugs
]
group_ids.update(team_slugs)
logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=group_ids,
is_public=False,
)
else:
# Internal repositories - accessible to organization members
logger.info(
f"Repository {repo.full_name} is internal - accessible to org members"
)
org_group_id = form_organization_group_id(repo.organization.id)
if add_prefix:
org_group_id = build_ext_group_name_for_onyx(
source=DocumentSource.GITHUB,
ext_group_name=org_group_id,
)
group_ids = {org_group_id}
logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
return ExternalAccess(
external_user_emails=set(),
external_user_group_ids=group_ids,
is_public=False,
        )


def get_external_user_group(
repo: Repository, github_client: Github
) -> list[ExternalUserGroup]:
"""
Get the external user group for a repository.
Creates ExternalUserGroup objects with actual user emails for each permission group.
"""
repo_visibility = get_repository_visibility(repo)
logger.info(
f"Generating ExternalUserGroups for {repo.full_name}: visibility={repo_visibility.value}"
)
if repo_visibility == GitHubVisibility.PRIVATE:
logger.info(f"Processing private repository {repo.full_name}")
collaborators, outside_collaborators = (
_get_collaborators_and_outside_collaborators(github_client, repo)
)
teams = _fetch_repository_teams_detailed(repo, github_client)
external_user_groups = []
user_emails = set()
for collab in collaborators:
if collab.email:
user_emails.add(collab.email)
else:
logger.error(f"Collaborator {collab.login} has no email")
if user_emails:
collaborators_group = ExternalUserGroup(
id=form_collaborators_group_id(repo.id),
user_emails=list(user_emails),
)
external_user_groups.append(collaborators_group)
logger.info(f"Created collaborators group with {len(user_emails)} emails")
# Create group for outside collaborators
user_emails = set()
for collab in outside_collaborators:
if collab.email:
user_emails.add(collab.email)
else:
logger.error(f"Outside collaborator {collab.login} has no email")
if user_emails:
outside_collaborators_group = ExternalUserGroup(
id=form_outside_collaborators_group_id(repo.id),
user_emails=list(user_emails),
)
external_user_groups.append(outside_collaborators_group)
logger.info(
f"Created outside collaborators group with {len(user_emails)} emails"
)
# Create groups for teams
for team in teams:
user_emails = set()
for member in team.members:
if member.email:
user_emails.add(member.email)
else:
logger.error(f"Team member {member.login} has no email")
if user_emails:
team_group = ExternalUserGroup(
id=team.slug,
user_emails=list(user_emails),
)
external_user_groups.append(team_group)
logger.info(
f"Created team group {team.name} with {len(user_emails)} emails"
)
logger.info(
f"Created {len(external_user_groups)} ExternalUserGroups for private repository {repo.full_name}"
)
return external_user_groups
if repo_visibility == GitHubVisibility.INTERNAL:
logger.info(f"Processing internal repository {repo.full_name}")
org_group_id = form_organization_group_id(repo.organization.id)
org_members = _fetch_organization_members(
github_client, repo.organization.login
)
user_emails = set()
for member in org_members:
if member.email:
user_emails.add(member.email)
else:
logger.error(f"Org member {member.login} has no email")
org_group = ExternalUserGroup(
id=org_group_id,
user_emails=list(user_emails),
)
logger.info(
f"Created organization group with {len(user_emails)} emails for internal repository {repo.full_name}"
)
return [org_group]
logger.info(f"Repository {repo.full_name} is public - no user groups needed")
return []

View File

@@ -3,6 +3,7 @@ from datetime import datetime
from datetime import timezone
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
@@ -35,6 +36,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -8,6 +8,7 @@ from ee.onyx.external_permissions.google_drive.permission_retrieval import (
get_permissions_by_ids,
)
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.connector import GoogleDriveConnector
@@ -169,6 +170,7 @@ def get_external_access_for_raw_gdrive_file(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -1,6 +1,7 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -17,6 +18,7 @@ JIRA_DOC_SYNC_TAG = "jira_doc_sync"
def jira_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
jira_connector = JiraConnector(
@@ -26,7 +28,7 @@ def jira_doc_sync(
yield from generic_doc_sync(
cc_pair=cc_pair,
        fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.JIRA,
slim_connector=jira_connector,

View File

@@ -5,6 +5,8 @@ from typing import Protocol
from typing import TYPE_CHECKING
from onyx.context.search.models import InferenceChunk
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
# Avoid circular imports
if TYPE_CHECKING:
@@ -15,14 +17,34 @@ if TYPE_CHECKING:
class FetchAllDocumentsFunction(Protocol):
    """Protocol for a function that fetches documents for a connector credential pair.

    This protocol defines the interface for functions that retrieve documents
    from the database, typically used in permission synchronization workflows.
    """

    def __call__(
        self,
        sort_order: SortOrder | None,
    ) -> list[DocumentRow]:
        """
        Fetches documents for a connector credential pair.

        This is typically used to determine which documents should no longer be
        accessible during the document sync process.
        """
        ...


class FetchAllDocumentsIdsFunction(Protocol):
    """Protocol for a function that fetches document IDs for a connector credential pair.

    This protocol defines the interface for functions that retrieve document IDs
    from the database, typically used in permission synchronization workflows.
    """

    def __call__(
        self,
    ) -> list[str]:
        """
        Fetches document IDs for a connector credential pair.
        """
        ...
@@ -32,6 +54,7 @@ DocSyncFuncType = Callable[
[
"ConnectorCredentialPair",
FetchAllDocumentsFunction,
FetchAllDocumentsIdsFunction,
Optional["IndexingHeartbeatInterface"],
],
Generator["DocExternalAccess", None, None],

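Because these are `typing.Protocol` classes, any callable with a matching signature satisfies them structurally; no inheritance is required. A minimal sketch using a stand-in ids protocol (an in-memory dict plays the role of the database):

```python
from typing import Protocol


class FetchAllDocumentsIdsFunction(Protocol):
    # Stand-in for the real protocol: a zero-arg callable returning list[str].
    def __call__(self) -> list[str]: ...


def make_ids_fetcher(docs: dict[str, str]) -> FetchAllDocumentsIdsFunction:
    # A plain closure over an in-memory "database" satisfies the Protocol
    # structurally -- no subclassing needed.
    def fetch_ids() -> list[str]:
        return sorted(docs)

    return fetch_ids


fetch = make_ids_fetcher({"doc-2": "...", "doc-1": "..."})
print(fetch())  # ['doc-1', 'doc-2']
```

This is why the sync functions below can accept `fetch_all_existing_docs_ids_fn` as a bare closure built by the caller.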
View File

@@ -3,6 +3,7 @@ from collections.abc import Generator
from slack_sdk import WebClient
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
@@ -130,6 +131,7 @@ def _get_slack_document_access(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
"""

View File

@@ -7,12 +7,16 @@ from pydantic import BaseModel
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
from ee.onyx.external_permissions.github.doc_sync import github_doc_sync
from ee.onyx.external_permissions.github.group_sync import github_group_sync
from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync
from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
@@ -20,6 +24,7 @@ from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync
from ee.onyx.external_permissions.perm_sync_types import CensoringFuncType
from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType
from ee.onyx.external_permissions.salesforce.postprocessing import (
censor_salesforce_chunks,
@@ -63,6 +68,7 @@ class SyncConfig(BaseModel):
def mock_doc_sync(
cc_pair: "ConnectorCredentialPair",
fetch_all_docs_fn: FetchAllDocumentsFunction,
fetch_all_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: Optional["IndexingHeartbeatInterface"],
) -> Generator["DocExternalAccess", None, None]:
"""Mock doc sync function for testing - returns empty list since permissions are fetched during indexing"""
@@ -117,6 +123,18 @@ _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = {
initial_index_should_sync=False,
),
),
DocumentSource.GITHUB: SyncConfig(
doc_sync_config=DocSyncConfig(
doc_sync_frequency=GITHUB_PERMISSION_DOC_SYNC_FREQUENCY,
doc_sync_func=github_doc_sync,
initial_index_should_sync=True,
),
group_sync_config=GroupSyncConfig(
group_sync_frequency=GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY,
group_sync_func=github_group_sync,
group_sync_is_cc_pair_agnostic=False,
),
),
DocumentSource.SALESFORCE: SyncConfig(
censoring_config=CensoringConfig(
chunk_censoring_func=censor_salesforce_chunks,

View File

@@ -1,6 +1,7 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from ee.onyx.external_permissions.utils import generic_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
@@ -18,6 +19,7 @@ TEAMS_DOC_SYNC_LABEL = "teams_doc_sync"
def teams_doc_sync(
cc_pair: ConnectorCredentialPair,
fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
) -> Generator[DocExternalAccess, None, None]:
teams_connector = TeamsConnector(
@@ -27,7 +29,7 @@ def teams_doc_sync(
yield from generic_doc_sync(
cc_pair=cc_pair,
        fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn,
callback=callback,
doc_source=DocumentSource.TEAMS,
slim_connector=teams_connector,

View File

@@ -1,6 +1,6 @@
from collections.abc import Generator
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
@@ -14,7 +14,7 @@ logger = setup_logger()
def generic_doc_sync(
cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,
callback: IndexingHeartbeatInterface | None,
doc_source: DocumentSource,
slim_connector: SlimConnector,
@@ -62,9 +62,9 @@ def generic_doc_sync(
)
logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id=}")
    existing_doc_ids: list[str] = fetch_all_existing_docs_ids_fn()
    missing_doc_ids = set(existing_doc_ids) - newly_fetched_doc_ids
if not missing_doc_ids:
return

View File

@@ -206,7 +206,7 @@ def _handle_standard_answers(
restate_question_blocks = get_restate_blocks(
msg=query_msg.message,
is_bot_msg=message_info.is_bot_msg,
is_slash_command=message_info.is_slash_command,
)
answer_blocks = build_standard_answer_blocks(

View File

@@ -67,7 +67,7 @@ def generate_chat_messages_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
        file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)
@@ -99,7 +99,7 @@ def generate_user_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
        file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)

View File

@@ -231,10 +231,7 @@ class DynamicTenantScheduler(PersistentScheduler):
True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys())
        return current_tasks == new_tasks
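The simplified comparison relies on set equality being order-insensitive, so two schedules with the same task names in any order compare equal. A self-contained sketch of the comparison (hypothetical local name, not the real scheduler method):

```python
def schedules_equivalent(schedule1, schedule2) -> bool:
    # schedule1: iterable of (task_name, entry) pairs; schedule2: mapping name -> entry.
    # Only the task-name sets are compared, and set equality ignores ordering.
    current_tasks = set(name for name, _ in schedule1)
    new_tasks = set(schedule2.keys())
    return current_tasks == new_tasks


print(schedules_equivalent([("sync", 1), ("prune", 2)], {"prune": 9, "sync": 9}))  # True
print(schedules_equivalent([("sync", 1)], {"prune": 9}))  # False
```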
@beat_init.connect

View File

@@ -32,7 +32,6 @@ from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
@@ -161,7 +160,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
RedisUserGroup.reset_all(r)
RedisConnectorDelete.reset_all(r)
RedisConnectorPrune.reset_all(r)
RedisConnectorIndex.reset_all(r)
RedisConnectorStop.reset_all(r)
RedisConnectorPermissionSync.reset_all(r)
RedisConnectorExternalGroupSync.reset_all(r)

View File

@@ -1,3 +1,5 @@
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from pathlib import Path
@@ -8,10 +10,12 @@ import httpx
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
@@ -22,12 +26,14 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
    doc_batch: Iterator[list[Document]],
) -> Generator[set[str], None, None]:
    for doc_list in doc_batch:
        yield {doc.id for doc in doc_list}
def extract_ids_from_runnable_connector(
@@ -46,33 +52,50 @@ def extract_ids_from_runnable_connector(
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
    doc_batch_id_generator = None

    if isinstance(runnable_connector, LoadConnector):
        doc_batch_id_generator = document_batch_to_ids(
            runnable_connector.load_from_state()
        )
    elif isinstance(runnable_connector, PollConnector):
        start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
        end = datetime.now(timezone.utc).timestamp()
        doc_batch_id_generator = document_batch_to_ids(
            runnable_connector.poll_source(start=start, end=end)
        )
elif isinstance(runnable_connector, CheckpointedConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
checkpoint = runnable_connector.build_dummy_checkpoint()
checkpoint_generator = runnable_connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
doc_batch_id_generator = batched_doc_ids(
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
    # this function is called per batch for rate limiting
    def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
        return doc_batch_ids

    if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
        doc_batch_processing_func = rate_limit_builder(
            max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
        )(lambda x: x)

    for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
        all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))

        if callback:
            callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
return all_connector_doc_ids
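The refactor above switches pruning from materializing whole document batches to streaming only their ID sets. A standalone sketch of that generator shape, without the real connector classes (plain dicts stand in for `Document`):

```python
from collections.abc import Generator, Iterator


def batch_ids(batches: Iterator[list[dict]]) -> Generator[set[str], None, None]:
    # Yield one set of IDs per incoming batch instead of holding the
    # full documents in memory for the whole run.
    for batch in batches:
        yield {doc["id"] for doc in batch}


all_ids: set[str] = set()
batches = iter([[{"id": "a"}, {"id": "b"}], [{"id": "b"}, {"id": "c"}]])
for ids in batch_ids(batches):
    all_ids.update(ids)

print(sorted(all_ids))  # ['a', 'b', 'c']
```

Because each yielded item is one batch's IDs, per-batch hooks like rate limiting and progress callbacks slot naturally into the consuming loop, as in the source.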

View File

@@ -193,12 +193,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()

View File

@@ -2,7 +2,6 @@ import multiprocessing
import os
import time
import traceback
from http import HTTPStatus
from time import sleep
import sentry_sdk
@@ -22,7 +21,7 @@ from onyx.background.celery.tasks.models import SimpleJobResult
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_docfetching import run_docfetching_entrypoint
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.exceptions import ConnectorValidationError
@@ -34,7 +33,6 @@ from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import SENTRY_DSN
@@ -156,7 +154,6 @@ def _docfetching_task(
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector.new_index(search_settings_id)
# TODO: remove all fences, cause all signals to be set in postgres
if redis_connector.delete.fenced:
@@ -214,7 +211,7 @@ def _docfetching_task(
)
# This is where the heavy/real work happens
    run_docfetching_entrypoint(
app,
index_attempt_id,
tenant_id,
@@ -261,7 +258,7 @@ def _docfetching_task(
def process_job_result(
job: SimpleJob,
connector_source: str | None,
    index_attempt_id: int,
log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
result = SimpleJobResult()
@@ -278,13 +275,11 @@ def process_job_result(
# In EKS, there is an edge case where successful tasks return exit
# code 1 in the cloud due to the set_spawn_method not sticking.
# We've since worked around this, but the following is a safe way to
            # work around this issue.
            # Workaround: check that the total number of batches is set, since this only
            # happens when docfetching completed successfully
            with get_session_with_current_tenant() as db_session:
                index_attempt = get_index_attempt(db_session, index_attempt_id)
                if index_attempt and index_attempt.total_batches is not None:
ignore_exitcode = True
if ignore_exitcode:
@@ -458,9 +453,6 @@ def docfetching_proxy_task(
)
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
@@ -487,7 +479,7 @@ def docfetching_proxy_task(
if job.done():
try:
result = process_job_result(
                    job, result.connector_source, index_attempt_id, log_builder
)
except Exception:
task_logger.exception(
@@ -552,15 +544,20 @@ def docfetching_proxy_task(
# print with exception
try:
with get_session_with_current_tenant() as db_session:
                attempt = get_index_attempt(db_session, ctx.index_attempt_id)
                # only mark failures if not already terminal,
                # otherwise we're overwriting potential real stack traces
                if attempt and not attempt.status.is_terminal():
                    failure_reason = (
                        f"Spawned task exceptioned: exit_code={result.exit_code}"
                    )
                    mark_attempt_failed(
                        ctx.index_attempt_id,
                        db_session,
                        failure_reason=failure_reason,
                        full_exception_trace=result.exception_str,
                    )
except Exception:
task_logger.exception(
log_builder.build(

View File

@@ -4,7 +4,6 @@ from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from http import HTTPStatus
from typing import Any
from celery import shared_task
@@ -16,6 +15,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
@@ -66,6 +67,7 @@ from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.indexing_coordination import CoordinationStatus
from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.search_settings import get_active_search_settings_list
@@ -102,6 +104,7 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
USER_FILE_INDEXING_LIMIT = 100
DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER = 4
def _get_fence_validation_block_expiration() -> int:
@@ -257,7 +260,7 @@ class ConnectorIndexingLogBuilder:
def monitor_indexing_attempt_progress(
    attempt: IndexAttempt, tenant_id: str, db_session: Session, task: Task
) -> None:
"""
TODO: rewrite this docstring
@@ -316,7 +319,9 @@ def monitor_indexing_attempt_progress(
# Check task completion using Celery
try:
        check_indexing_completion(
            attempt.id, coordination_status, storage, tenant_id, task
        )
except Exception as e:
logger.exception(
f"Failed to monitor document processing completion: "
@@ -350,6 +355,7 @@ def check_indexing_completion(
coordination_status: CoordinationStatus,
storage: DocumentBatchStorage,
tenant_id: str,
task: Task,
) -> None:
logger.info(
@@ -376,20 +382,78 @@ def check_indexing_completion(
# Update progress tracking and check for stalls
with get_session_with_current_tenant() as db_session:
# Update progress tracking
stalled_timeout_hours = INDEXING_PROGRESS_TIMEOUT_HOURS
# Index attempts that are waiting between docfetching and
# docprocessing get a generous stalling timeout
if batches_total is not None and batches_processed == 0:
stalled_timeout_hours = (
stalled_timeout_hours * DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER
)
timed_out = not IndexingCoordination.update_progress_tracking(
            db_session,
            index_attempt_id,
            batches_processed,
            timeout_hours=stalled_timeout_hours,
        )
# Check for stalls (3-6 hour timeout). Only applies to in-progress attempts.
attempt = get_index_attempt(db_session, index_attempt_id)
if attempt and timed_out:
if attempt.status == IndexingStatus.IN_PROGRESS:
logger.error(
f"Indexing attempt {index_attempt_id} has been indexing for "
f"{stalled_timeout_hours//2}-{stalled_timeout_hours} hours without progress. "
f"Marking it as failed."
)
mark_attempt_failed(
index_attempt_id, db_session, failure_reason="Stalled indexing"
)
elif (
attempt.status == IndexingStatus.NOT_STARTED and attempt.celery_task_id
):
# Check if the task exists in the celery queue
# This handles the case where Redis dies after task creation but before task execution
redis_celery = task.app.broker_connection().channel().client # type: ignore
task_exists = celery_find_task(
attempt.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
redis_celery,
)
unacked_task_ids = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, redis_celery
)
if not task_exists and attempt.celery_task_id not in unacked_task_ids:
# There is a race condition where the docfetching task has been taken off
# the queue (i.e. started) but the indexing attempt still has a status of
# NOT_STARTED because the switch to IN_PROGRESS takes ~0.1 seconds.
# Sleep briefly and confirm that the attempt is still not in progress.
time.sleep(1)
attempt = get_index_attempt(db_session, index_attempt_id)
if attempt and attempt.status == IndexingStatus.NOT_STARTED:
logger.error(
f"Task {attempt.celery_task_id} attached to indexing attempt "
f"{index_attempt_id} does not exist in the queue. "
f"Marking indexing attempt as failed."
)
mark_attempt_failed(
index_attempt_id,
db_session,
failure_reason="Task not in queue",
)
else:
logger.info(
f"Indexing attempt {index_attempt_id} is {attempt.status}. "
f"{stalled_timeout_hours//2}-{stalled_timeout_hours} hours without heartbeat, "
"but the task is in the queue. Likely an underprovisioned docfetching worker."
)
# Update last progress time so we won't time out again until the stall timeout elapses
IndexingCoordination.update_progress_tracking(
db_session,
index_attempt_id,
batches_processed,
force_update_progress=True,
)
# check again on the next check_for_indexing task
# TODO: on the cloud this is currently 25 minutes at most, which
@@ -449,15 +513,6 @@ def check_indexing_completion(
db_session=db_session,
)
# TODO: make it so we don't need this (might already be true)
redis_connector = RedisConnector(
tenant_id, attempt.connector_credential_pair_id
)
redis_connector_index = redis_connector.new_index(
attempt.search_settings_id
)
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
# Clean up FileStore storage (still needed for document batches during transition)
try:
logger.info(f"Cleaning up storage after indexing completion: {storage}")
@@ -811,7 +866,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
for attempt in active_attempts:
try:
monitor_indexing_attempt_progress(attempt, tenant_id, db_session)
monitor_indexing_attempt_progress(
attempt, tenant_id, db_session, self
)
except Exception:
task_logger.exception(f"Error monitoring attempt {attempt.id}")
@@ -1085,12 +1142,8 @@ def _docprocessing_task(
f"Index attempt {index_attempt_id} is not running, status {index_attempt.status}"
)
redis_connector_index = redis_connector.new_index(
index_attempt.search_settings.id
)
cross_batch_db_lock: RedisLock = r.lock(
redis_connector_index.db_lock_key,
redis_connector.db_lock_key(index_attempt.search_settings.id),
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
thread_local=False,
)
@@ -1230,17 +1283,6 @@ def _docprocessing_task(
f"attempt={index_attempt_id} "
)
# on failure, signal completion with an error to unblock the watchdog
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt and index_attempt.search_settings:
redis_connector_index = redis_connector.new_index(
index_attempt.search_settings.id
)
redis_connector_index.set_generator_complete(
HTTPStatus.INTERNAL_SERVER_ERROR.value
)
raise
finally:
if per_batch_lock and per_batch_lock.owned():


@@ -47,7 +47,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.tag import delete_orphan_tags__no_commit
@@ -70,9 +69,9 @@ logger = setup_logger()
def _get_pruning_block_expiration() -> int:
"""
Compute the expiration time for the pruning block signal.
Base expiration is 3600 seconds (1 hour), multiplied by the beat multiplier only in MULTI_TENANT mode.
Base expiration is 60 seconds (1 minute), multiplied by the beat multiplier only in MULTI_TENANT mode.
"""
base_expiration = 3600 # seconds
base_expiration = 60 # seconds
if not MULTI_TENANT:
return base_expiration
@@ -145,10 +144,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
last_pruned = cc_pair.connector.time_created
next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq)
if datetime.now(timezone.utc) < next_prune:
return False
return True
return datetime.now(timezone.utc) >= next_prune
@shared_task(
@@ -280,6 +276,9 @@ def try_creating_prune_generator_task(
if not ALLOW_SIMULTANEOUS_PRUNING:
count = redis_connector.prune.get_active_task_count()
if count > 0:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} no simultaneous pruning allowed"
)
return None
LOCK_TIMEOUT = 30
@@ -293,6 +292,9 @@ def try_creating_prune_generator_task(
acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
if not acquired:
logger.info(
f"try_creating_prune_generator_task: cc_pair={cc_pair.id} lock not acquired"
)
return None
try:
@@ -516,9 +518,6 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
search_settings = get_current_search_settings(db_session)
redis_connector.new_index(search_settings.id)
callback = PruneCallback(
0,
redis_connector,


@@ -226,8 +226,12 @@ def _check_connector_and_attempt_status(
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
error_str = ""
if index_attempt_loop.error_msg:
error_str = f" Original error: {index_attempt_loop.error_msg}"
raise RuntimeError(
f"Index Attempt is not running, status is {index_attempt_loop.status}"
f"Index Attempt is not running, status is {index_attempt_loop.status}.{error_str}"
)
if index_attempt_loop.celery_task_id is None:
@@ -832,7 +836,7 @@ def _run_indexing(
)
def run_indexing_entrypoint(
def run_docfetching_entrypoint(
app: Celery,
index_attempt_id: int,
tenant_id: str,
@@ -1350,6 +1354,9 @@ def reissue_old_batches(
)
path_info = batch_storage.extract_path_info(batch_id)
if path_info is None:
logger.warning(
f"Could not extract path info from batch {batch_id}, skipping"
)
continue
if path_info.cc_pair_id != cc_pair_id:
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")


@@ -359,6 +359,12 @@ POLL_CONNECTOR_OFFSET = 30 # Minutes overlap between poll windows
# only very select connectors are enabled and admins cannot add other connector types
ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
# If set to true, curators can only access and edit assistants that they created
CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS = (
os.environ.get("CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS", "").lower()
== "true"
)
# Some calls to get information on expert users are quite costly especially with rate limiting
# Since experts are not used in the actual user experience, currently it is turned off
# for some connectors


@@ -25,9 +25,32 @@ TimeRange = tuple[datetime, datetime]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def batched_doc_ids(
checkpoint_connector_generator: CheckpointOutput[CT],
batch_size: int,
) -> Generator[set[str], None, None]:
batch: set[str] = set()
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
batch.add(document.id)
elif (
failure and failure.failed_document and failure.failed_document.document_id
):
batch.add(failure.failed_document.document_id)
if len(batch) >= batch_size:
yield batch
batch = set()
if len(batch) > 0:
yield batch
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format.
Wraps a CheckpointOutput generator to give things back in a more digestible format,
specifically for Document outputs.
The connector format is easier for the connector implementor (e.g. it enforces that
exactly one new checkpoint is returned AND that the checkpoint comes at the end),
hence the two different formats.
@@ -131,7 +154,7 @@ class ConnectorRunner(Generic[CT]):
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
if document is not None and isinstance(document, Document):
self.doc_batch.append(document)
if failure is not None:


@@ -222,11 +222,21 @@ class LocalFileConnector(LoadConnector):
"""
Connector that reads files from Postgres and yields Documents, including
embedded image extraction without summarization.
file_locations are S3/Filestore UUIDs
file_names are the names of the files
"""
# Note: file_names is a required parameter, but should not break backwards compatibility.
# If add_file_names migration is not run, old file connector configs will not have file_names.
# This is fine because the configs are not re-used to instantiate the connector.
# file_names is only used for display purposes in the UI and file_locations is used as a fallback.
def __init__(
self,
file_locations: list[Path | str],
file_names: list[
str
], # Must accept this parameter as connector_specific_config is unpacked as args
zip_metadata: dict[str, Any],
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
@@ -260,7 +270,7 @@ class LocalFileConnector(LoadConnector):
logger.warning(f"No file record found for '{file_id}' in PG; skipping.")
continue
metadata = self._get_file_metadata(file_id)
metadata = self._get_file_metadata(file_record.display_name)
file_io = file_store.read_file(file_id=file_id, mode="b")
new_docs = _process_file(
file_id=file_id,
@@ -282,7 +292,9 @@ class LocalFileConnector(LoadConnector):
if __name__ == "__main__":
connector = LocalFileConnector(
file_locations=[os.environ["TEST_FILE"]], zip_metadata={}
file_locations=[os.environ["TEST_FILE"]],
file_names=[os.environ["TEST_FILE"]],
zip_metadata={},
)
connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")})
doc_batches = connector.load_from_state()


@@ -1,5 +1,4 @@
import copy
import time
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
@@ -17,17 +16,22 @@ from github.Issue import Issue
from github.NamedUser import NamedUser
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from github.Requester import Requester
from pydantic import BaseModel
from typing_extensions import override
from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL
from onyx.configs.constants import DocumentSource
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.github.models import SerializedRepository
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
from onyx.connectors.github.utils import deserialize_repository
from onyx.connectors.github.utils import get_external_access_permission
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import ConnectorFailure
@@ -46,17 +50,7 @@ CURSOR_LOG_FREQUENCY = 50
_MAX_NUM_RATE_LIMIT_RETRIES = 5
ONE_DAY = timedelta(days=1)
def _sleep_after_rate_limit_exception(github_client: Github) -> None:
sleep_time = github_client.get_rate_limit().core.reset.replace(
tzinfo=timezone.utc
) - datetime.now(tz=timezone.utc)
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
time.sleep(sleep_time.seconds)
SLIM_BATCH_SIZE = 100
# Cases
# X (from start) standard run, no fallback to cursor-based pagination
# X (from start) standard run errors, fallback to cursor-based pagination
@@ -72,6 +66,10 @@ def _sleep_after_rate_limit_exception(github_client: Github) -> None:
# checkpoint progress (no infinite loop)
class DocMetadata(BaseModel):
repo: str
def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str:
if "_PaginatedList__nextUrl" in pag_list.__dict__:
return "_PaginatedList__nextUrl"
@@ -190,7 +188,7 @@ def _get_batch_rate_limited(
getattr(obj, "raw_data")
yield from objs
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
sleep_after_rate_limit_exception(github_client)
yield from _get_batch_rate_limited(
git_objs,
page_num,
@@ -232,12 +230,17 @@ def _get_userinfo(user: NamedUser) -> dict[str, str]:
}
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
def _convert_pr_to_document(
pull_request: PullRequest, repo_external_access: ExternalAccess | None
) -> Document:
repo_name = pull_request.base.repo.full_name if pull_request.base else ""
doc_metadata = DocMetadata(repo=repo_name)
return Document(
id=pull_request.html_url,
sections=[
TextSection(link=pull_request.html_url, text=pull_request.body or "")
],
external_access=repo_external_access,
source=DocumentSource.GITHUB,
semantic_identifier=f"{pull_request.number}: {pull_request.title}",
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -248,6 +251,8 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
if pull_request.updated_at
else None
),
# this metadata is used in perm sync
doc_metadata=doc_metadata.model_dump(),
metadata={
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
@@ -301,14 +306,21 @@ def _fetch_issue_comments(issue: Issue) -> str:
return "\nComment: ".join(comment.body for comment in comments)
def _convert_issue_to_document(issue: Issue) -> Document:
def _convert_issue_to_document(
issue: Issue, repo_external_access: ExternalAccess | None
) -> Document:
repo_name = issue.repository.full_name if issue.repository else ""
doc_metadata = DocMetadata(repo=repo_name)
return Document(
id=issue.html_url,
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
source=DocumentSource.GITHUB,
external_access=repo_external_access,
semantic_identifier=f"{issue.number}: {issue.title}",
# updated_at is UTC time but is timezone unaware
doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
# this metadata is used in perm sync
doc_metadata=doc_metadata.model_dump(),
metadata={
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
@@ -343,18 +355,6 @@ def _convert_issue_to_document(issue: Issue) -> Document:
)
class SerializedRepository(BaseModel):
# id is part of the raw_data as well, just pulled out for convenience
id: int
headers: dict[str, str | int]
raw_data: dict[str, Any]
def to_Repository(self, requester: Requester) -> Repository.Repository:
return Repository.Repository(
requester, self.headers, self.raw_data, completed=True
)
class GithubConnectorStage(Enum):
START = "start"
PRS = "prs"
@@ -394,7 +394,7 @@ def make_cursor_url_callback(
return cursor_url_callback
class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint]):
def __init__(
self,
repo_owner: str,
@@ -423,7 +423,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
)
return None
def _get_github_repo(
def get_github_repo(
self, github_client: Github, attempt_num: int = 0
) -> Repository.Repository:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
@@ -434,10 +434,10 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
try:
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
sleep_after_rate_limit_exception(github_client)
return self.get_github_repo(github_client, attempt_num + 1)
def _get_github_repos(
def get_github_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
"""Get specific repositories based on comma-separated repo_name string."""
@@ -465,10 +465,10 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
return repos
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repos(github_client, attempt_num + 1)
sleep_after_rate_limit_exception(github_client)
return self.get_github_repos(github_client, attempt_num + 1)
def _get_all_repos(
def get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
@@ -487,8 +487,8 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
user = github_client.get_user(self.repo_owner)
return list(user.get_repos())
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_all_repos(github_client, attempt_num + 1)
sleep_after_rate_limit_exception(github_client)
return self.get_all_repos(github_client, attempt_num + 1)
def _pull_requests_func(
self, repo: Repository.Repository
@@ -509,6 +509,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
checkpoint: GithubConnectorCheckpoint,
start: datetime | None = None,
end: datetime | None = None,
include_permissions: bool = False,
) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]:
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
@@ -521,13 +522,13 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
repos = self.get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
repos = [self.get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
repos = self.get_all_repos(self.github_client)
if not repos:
checkpoint.has_more = False
return checkpoint
@@ -547,28 +548,15 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
if checkpoint.cached_repo is None:
raise ValueError("No repo saved in checkpoint")
# Try to access the requester - different PyGithub versions may use different attribute names
try:
# Try direct access to a known attribute name first
if hasattr(self.github_client, "_requester"):
requester = self.github_client._requester
elif hasattr(self.github_client, "_Github__requester"):
requester = self.github_client._Github__requester
else:
# If we can't find the requester attribute, we need to fall back to recreating the repo
raise AttributeError("Could not find requester attribute")
repo = checkpoint.cached_repo.to_Repository(requester)
except Exception as e:
# If all else fails, re-fetch the repo directly
logger.warning(
f"Failed to deserialize repository: {e}. Attempting to re-fetch."
)
repo_id = checkpoint.cached_repo.id
repo = self.github_client.get_repo(repo_id)
# Deserialize the repository from the checkpoint
repo = deserialize_repository(checkpoint.cached_repo, self.github_client)
cursor_url_callback = make_cursor_url_callback(checkpoint)
repo_external_access: ExternalAccess | None = None
if include_permissions:
repo_external_access = get_external_access_permission(
repo, self.github_client
)
if self.include_prs and checkpoint.stage == GithubConnectorStage.PRS:
logger.info(f"Fetching PRs for repo: {repo.name}")
@@ -603,7 +591,9 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
):
continue
try:
yield _convert_pr_to_document(cast(PullRequest, pr))
yield _convert_pr_to_document(
cast(PullRequest, pr), repo_external_access
)
except Exception as e:
error_msg = f"Error converting PR to document: {e}"
logger.exception(error_msg)
@@ -653,6 +643,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
self.github_client,
)
)
logger.info(f"Fetched {len(issue_batch)} issues for repo: {repo.name}")
checkpoint.curr_page += 1
done_with_issues = False
num_issues = 0
@@ -678,7 +669,7 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
continue
try:
yield _convert_issue_to_document(issue)
yield _convert_issue_to_document(issue, repo_external_access)
except Exception as e:
error_msg = f"Error converting issue to document: {e}"
logger.exception(error_msg)
@@ -715,12 +706,16 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
checkpoint.stage = GithubConnectorStage.PRS
checkpoint.reset()
logger.info(f"{len(checkpoint.cached_repo_ids)} repos remaining")
if checkpoint.cached_repo_ids:
logger.info(
f"{len(checkpoint.cached_repo_ids)} repos remaining (IDs: {checkpoint.cached_repo_ids})"
)
else:
logger.info("No more repos remaining")
return checkpoint
@override
def load_from_checkpoint(
def _load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
@@ -741,7 +736,32 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
adjusted_start_datetime = epoch
return self._fetch_from_github(
checkpoint, start=adjusted_start_datetime, end=end_datetime
checkpoint,
start=adjusted_start_datetime,
end=end_datetime,
include_permissions=include_permissions,
)
@override
def load_from_checkpoint(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=False
)
@override
def load_from_checkpoint_with_perm_sync(
self,
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
checkpoint: GithubConnectorCheckpoint,
) -> CheckpointOutput[GithubConnectorCheckpoint]:
return self._load_from_checkpoint(
start, end, checkpoint, include_permissions=True
)
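The two entry points above differ only in the `include_permissions` flag they pass to the shared implementation. The pattern in miniature (a toy class, not the real connector interface):

```python
class CheckpointedConnectorSketch:
    """Toy sketch: two public entry points sharing one implementation
    via an include_permissions flag."""

    def _load(self, start: float, end: float, include_permissions: bool) -> str:
        mode = "perm-sync" if include_permissions else "plain"
        return f"load[{mode}] {start}..{end}"

    def load_from_checkpoint(self, start: float, end: float) -> str:
        return self._load(start, end, include_permissions=False)

    def load_from_checkpoint_with_perm_sync(self, start: float, end: float) -> str:
        return self._load(start, end, include_permissions=True)


c = CheckpointedConnectorSketch()
assert c.load_from_checkpoint(0, 1) == "load[plain] 0..1"
assert c.load_from_checkpoint_with_perm_sync(0, 1) == "load[perm-sync] 0..1"
```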
def validate_connector_settings(self) -> None:
@@ -775,6 +795,9 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{repo_name}"
)
logger.info(
f"Successfully accessed repository: {self.repo_owner}/{repo_name}"
)
test_repo.get_contents("")
valid_repos = True
# If at least one repo is valid, we can proceed
@@ -882,7 +905,6 @@ class GithubConnector(CheckpointedConnector[GithubConnectorCheckpoint]):
if __name__ == "__main__":
import os
from onyx.connectors.connector_runner import ConnectorRunner
# Initialize the connector
connector = GithubConnector(
@@ -893,6 +915,12 @@ if __name__ == "__main__":
{"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
)
if connector.github_client:
get_external_access_permission(
connector.get_github_repos(connector.github_client).pop(),
connector.github_client,
)
# Create a time range from epoch to now
end_time = datetime.now(timezone.utc)
start_time = datetime.fromtimestamp(0, tz=timezone.utc)


@@ -0,0 +1,17 @@
from typing import Any
from github import Repository
from github.Requester import Requester
from pydantic import BaseModel
class SerializedRepository(BaseModel):
# id is part of the raw_data as well, just pulled out for convenience
id: int
headers: dict[str, str | int]
raw_data: dict[str, Any]
def to_Repository(self, requester: Requester) -> Repository.Repository:
return Repository.Repository(
requester, self.headers, self.raw_data, completed=True
)


@@ -0,0 +1,25 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from github import Github
from onyx.utils.logger import setup_logger
logger = setup_logger()
def sleep_after_rate_limit_exception(github_client: Github) -> None:
"""
Sleep until the GitHub rate limit resets.
Args:
github_client: The GitHub client that hit the rate limit
"""
sleep_time = github_client.get_rate_limit().core.reset.replace(
tzinfo=timezone.utc
) - datetime.now(tz=timezone.utc)
sleep_time += timedelta(minutes=1) # add an extra minute just to be safe
logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.total_seconds():.0f} seconds.")
time.sleep(sleep_time.total_seconds())
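The sleep computation is worth isolating, since the switch to `total_seconds()` is the substantive fix here: `timedelta.seconds` discards whole days and is wrong for negative deltas. A sketch with the GitHub client mocked away (the naive reset timestamp mimics what the rate-limit API may return):

```python
from datetime import datetime, timedelta, timezone


def compute_rate_limit_sleep(reset_at: datetime, now: datetime) -> float:
    """Seconds to sleep until a rate limit resets, plus a safety minute."""
    sleep_time = reset_at.replace(tzinfo=timezone.utc) - now
    sleep_time += timedelta(minutes=1)  # extra minute, just to be safe
    # total_seconds() matters: timedelta.seconds drops whole days and
    # misbehaves for negative deltas, so it can badly under-sleep.
    return sleep_time.total_seconds()


now = datetime(2025, 8, 7, 12, 0, tzinfo=timezone.utc)
reset = datetime(2025, 8, 7, 12, 30)  # naive timestamp
assert compute_rate_limit_sleep(reset, now) == 31 * 60
```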


@@ -0,0 +1,63 @@
from collections.abc import Callable
from typing import cast
from github import Github
from github.Repository import Repository
from onyx.access.models import ExternalAccess
from onyx.connectors.github.models import SerializedRepository
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
logger = setup_logger()
def get_external_access_permission(
repo: Repository, github_client: Github
) -> ExternalAccess:
"""
Get the external access permission for a repository.
This functionality requires Enterprise Edition.
"""
# Check if EE is enabled
if not global_version.is_ee_version():
# For the MIT version, return an empty ExternalAccess (private document)
return ExternalAccess.empty()
# Fetch the EE implementation
ee_get_external_access_permission = cast(
Callable[[Repository, Github, bool], ExternalAccess],
fetch_versioned_implementation(
"onyx.external_permissions.github.utils",
"get_external_access_permission",
),
)
return ee_get_external_access_permission(repo, github_client, True)
def deserialize_repository(
cached_repo: SerializedRepository, github_client: Github
) -> Repository:
"""
Deserialize a SerializedRepository back into a Repository object.
"""
# Try to access the requester - different PyGithub versions may use different attribute names
try:
# Try to get the requester using getattr to avoid linter errors
requester = getattr(github_client, "_requester", None)
if requester is None:
requester = getattr(github_client, "_Github__requester", None)
if requester is None:
# If we can't find the requester attribute, we need to fall back to recreating the repo
raise AttributeError("Could not find requester attribute")
return cached_repo.to_Repository(requester)
except Exception as e:
# If all else fails, re-fetch the repo directly
logger.warning(
f"Failed to deserialize repository: {e}. Attempting to re-fetch."
)
repo_id = cached_repo.id
return github_client.get_repo(repo_id)
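The `_Github__requester` fallback above relies on Python's name mangling of double-underscore attributes: `__requester` defined inside the `Github` class is stored as `_Github__requester`. A minimal illustration (the `Github` class here is a dummy, not PyGithub):

```python
class Github:
    def __init__(self) -> None:
        self.__requester = "requester-object"  # stored as _Github__requester


client = Github()

# The public-looking name is absent; the mangled name is what actually exists.
requester = getattr(client, "_requester", None)
if requester is None:
    requester = getattr(client, "_Github__requester", None)

assert requester == "requester-object"
```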


@@ -183,6 +183,7 @@ class DocumentBase(BaseModel):
# only filled in EE for connectors w/ permission sync enabled
external_access: ExternalAccess | None = None
doc_metadata: dict[str, Any] | None = None
def get_title_for_document_index(
self,


@@ -1,5 +1,6 @@
import csv
import gc
import json
import os
import sys
import tempfile
@@ -28,8 +29,12 @@ from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -38,27 +43,27 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
"Opportunity": {
"Account": "account",
ACCOUNT_OBJECT_TYPE: "account",
"FiscalQuarter": "fiscal_quarter",
"FiscalYear": "fiscal_year",
"IsClosed": "is_closed",
"Name": "name",
NAME_FIELD: "name",
"StageName": "stage_name",
"Type": "type",
"Amount": "amount",
"CloseDate": "close_date",
"Probability": "probability",
"CreatedDate": "created_date",
"LastModifiedDate": "last_modified_date",
MODIFIED_FIELD: "last_modified_date",
},
"Contact": {
"Account": "account",
ACCOUNT_OBJECT_TYPE: "account",
"CreatedDate": "created_date",
"LastModifiedDate": "last_modified_date",
MODIFIED_FIELD: "last_modified_date",
},
}
@@ -74,19 +79,77 @@ class SalesforceConnectorContext:
parent_to_child_types: dict[str, set[str]] = {} # map from parent to child types
child_to_parent_types: dict[str, set[str]] = {} # map from child to parent types
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {}
type_to_queryable_fields: dict[str, list[str]] = {}
type_to_queryable_fields: dict[str, set[str]] = {}
prefix_to_type: dict[str, str] = {} # infer the object type of an id immediately
parent_to_child_relationships: dict[str, set[str]] = (
{}
) # map from parent to child relationships
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = (
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = (
{}
) # map from relationship to queryable fields
parent_child_names_to_relationships: dict[str, str] = {}
def _extract_fields_and_associations_from_config(
config: dict[str, Any], object_type: str
) -> tuple[list[str] | None, dict[str, list[str]]]:
"""
Extract fields and associations for a specific object type from custom config.
Returns:
tuple of (fields_list, associations_dict)
- fields_list: List of fields to query, or None if not specified (use all)
- associations_dict: Dict mapping association names to their config
"""
if object_type not in config:
return None, {}
obj_config = config[object_type]
fields = obj_config.get("fields")
associations = obj_config.get("associations", {})
return fields, associations
def _validate_custom_query_config(config: dict[str, Any]) -> None:
"""
Validate the structure of the custom query configuration.
"""
for object_type, obj_config in config.items():
if not isinstance(obj_config, dict):
raise ValueError(
f"top level object {object_type} must be mapped to a dictionary"
)
# Check if fields is a list when present
if "fields" in obj_config:
if not isinstance(obj_config["fields"], list):
raise ValueError("if the fields key exists, its value must be a list")
for v in obj_config["fields"]:
if not isinstance(v, str):
raise ValueError(f"fields list value {v} is not a string")
# Check if associations is a dict when present
if "associations" in obj_config:
if not isinstance(obj_config["associations"], dict):
raise ValueError(
"if the associations key exists, its value must be a dictionary"
)
for assoc_name, assoc_fields in obj_config["associations"].items():
if not isinstance(assoc_fields, list):
raise ValueError(
f"associations value {assoc_fields} for key {assoc_name} is not a list"
)
for v in assoc_fields:
if not isinstance(v, str):
raise ValueError(
f"associations list value {v} is not a string"
)
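For reference, a config of the shape this validator accepts might look like the following (the object types and Salesforce field names are illustrative examples, not a documented schema):

```python
import json

# Hypothetical config: which fields to pull per object type, plus
# associated objects and the fields to pull from each association.
custom_query_config = {
    "Opportunity": {
        "fields": ["Name", "StageName", "Amount", "CloseDate"],
        "associations": {
            "Contacts": ["Name", "Email"],
        },
    },
    # An object type with no "fields" key means "use all queryable fields".
    "Case": {},
}

# The connector receives this as a JSON string and parses it:
parsed = json.loads(json.dumps(custom_query_config))
assert list(parsed.keys()) == ["Opportunity", "Case"]
assert parsed["Opportunity"]["associations"]["Contacts"] == ["Name", "Email"]
```

Per the `__init__` change above, the top-level keys of this dict also become the connector's parent object list.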
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
"""Approach outline
@@ -134,14 +197,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
batch_size: int = INDEX_BATCH_SIZE,
requested_objects: list[str] = [],
custom_query_config: str | None = None,
) -> None:
self.batch_size = batch_size
self._sf_client: OnyxSalesforce | None = None
self.parent_object_list = (
[obj.capitalize() for obj in requested_objects]
if requested_objects
else _DEFAULT_PARENT_OBJECT_TYPES
)
# Validate and store custom query config
if custom_query_config:
config_json = json.loads(custom_query_config)
self.custom_query_config: dict[str, Any] | None = config_json
# If custom query config is provided, use the object types from it
self.parent_object_list = list(config_json.keys())
else:
self.custom_query_config = None
# Use the traditional requested_objects approach
self.parent_object_list = (
[obj.strip().capitalize() for obj in requested_objects]
if requested_objects
else _DEFAULT_PARENT_OBJECT_TYPES
)
def load_credentials(
self,
@@ -187,7 +261,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
@staticmethod
def _download_object_csvs(
all_types_to_filter: dict[str, bool],
queryable_fields_by_type: dict[str, list[str]],
queryable_fields_by_type: dict[str, set[str]],
directory: str,
sf_client: OnyxSalesforce,
start: SecondsSinceUnixEpoch | None = None,
@@ -325,9 +399,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# all_types.update(child_types.keys())
# # Always want to make sure user is grabbed for permissioning purposes
# all_types.add("User")
# all_types.add(USER_OBJECT_TYPE)
# # Always want to make sure account is grabbed for reference purposes
# all_types.add("Account")
# all_types.add(ACCOUNT_OBJECT_TYPE)
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
@@ -351,7 +425,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# all_types.update(child_types)
# # Always want to make sure user is grabbed for permissioning purposes
# all_types.add("User")
# all_types.add(USER_OBJECT_TYPE)
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
@@ -364,7 +438,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
) -> GenerateDocumentsOutput:
type_to_processed: dict[str, int] = {}
logger.info("_fetch_from_salesforce starting.")
logger.info("_fetch_from_salesforce starting (full sync).")
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
@@ -548,7 +622,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
) -> GenerateDocumentsOutput:
type_to_processed: dict[str, int] = {}
logger.info("_fetch_from_salesforce starting.")
logger.info("_fetch_from_salesforce starting (delta sync).")
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
@@ -677,7 +751,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
try:
last_modified_by_id = record["LastModifiedById"]
user_record = self.sf_client.query_object(
"User", last_modified_by_id, ctx.type_to_queryable_fields
USER_OBJECT_TYPE,
last_modified_by_id,
ctx.type_to_queryable_fields,
)
if user_record:
primary_owner = BasicExpertInfo.from_dict(user_record)
@@ -792,7 +868,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = (
{}
) # for a given object, the fields reference parent objects
type_to_queryable_fields: dict[str, list[str]] = {}
type_to_queryable_fields: dict[str, set[str]] = {}
prefix_to_type: dict[str, str] = {}
parent_to_child_relationships: dict[str, set[str]] = (
@@ -802,15 +878,13 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# relationship keys are formatted as "parent__relationship"
# we have to do this because relationship names are not unique!
# values are a dict of relationship names to a list of queryable fields
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = {}
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = {}
parent_child_names_to_relationships: dict[str, str] = {}
full_sync = False
if start is None and end is None:
full_sync = True
full_sync = start is None and end is None
# Step 1 - make a list of all the types to download (parent + direct child + "User")
# Step 1 - make a list of all the types to download (parent + direct child + USER_OBJECT_TYPE)
# prefixes = {}
global_description = sf_client.describe()
@@ -831,16 +905,58 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
for parent_type in parent_types:
# parent_onyx_sf_type = OnyxSalesforceType(parent_type, sf_client)
type_to_queryable_fields[parent_type] = (
sf_client.get_queryable_fields_by_type(parent_type)
)
child_types_working = sf_client.get_children_of_sf_type(parent_type)
logger.debug(
f"Found {len(child_types_working)} child types for {parent_type}"
)
custom_fields: list[str] | None = []
associations_config: dict[str, list[str]] | None = None
# parent_to_child_relationships[parent_type] = child_types_working
# Set queryable fields for parent type
if self.custom_query_config:
custom_fields, associations_config = (
_extract_fields_and_associations_from_config(
self.custom_query_config, parent_type
)
)
custom_fields = custom_fields or []
# Get custom fields for parent type
field_set = set(custom_fields)
# these are expected and used during doc conversion
field_set.add(NAME_FIELD)
field_set.add(MODIFIED_FIELD)
# Use only the specified fields
type_to_queryable_fields[parent_type] = field_set
logger.info(f"Using custom fields for {parent_type}: {field_set}")
else:
# Use all queryable fields
type_to_queryable_fields[parent_type] = (
sf_client.get_queryable_fields_by_type(parent_type)
)
logger.info(f"Using all fields for {parent_type}")
child_types_all = sf_client.get_children_of_sf_type(parent_type)
logger.debug(f"Found {len(child_types_all)} child types for {parent_type}")
logger.debug(f"child types: {child_types_all}")
child_types_working = child_types_all.copy()
if associations_config is not None:
child_types_working = {
k: v for k, v in child_types_all.items() if k in associations_config
}
any_not_found = False
for k in associations_config:
if k not in child_types_working:
any_not_found = True
logger.warning(f"Association {k} not found in {parent_type}")
if any_not_found:
raise RuntimeError(
f"Associations {associations_config} not found in {parent_type} "
f"with child objects {child_types_all.keys()}"
)
parent_to_child_relationships[parent_type] = set()
parent_to_child_types[parent_type] = set()
parent_to_relationship_queryable_fields[parent_type] = {}
for child_type, child_relationship in child_types_working.items():
child_type = cast(str, child_type)
@@ -848,8 +964,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
# map parent name to child name
if parent_type not in parent_to_child_types:
parent_to_child_types[parent_type] = set()
parent_to_child_types[parent_type].add(child_type)
# reverse map child name to parent name
@@ -858,19 +972,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
child_to_parent_types[child_type].add(parent_type)
# map parent name to child relationship
if parent_type not in parent_to_child_relationships:
parent_to_child_relationships[parent_type] = set()
parent_to_child_relationships[parent_type].add(child_relationship)
# map relationship to queryable fields of the target table
queryable_fields = sf_client.get_queryable_fields_by_type(child_type)
if config_fields := (
associations_config and associations_config.get(child_type)
):
field_set = set(config_fields)
# these are expected and used during doc conversion
field_set.add(NAME_FIELD)
field_set.add(MODIFIED_FIELD)
queryable_fields = field_set
else:
queryable_fields = sf_client.get_queryable_fields_by_type(
child_type
)
if child_relationship in parent_to_relationship_queryable_fields:
raise RuntimeError(f"{child_relationship=} already exists")
if parent_type not in parent_to_relationship_queryable_fields:
parent_to_relationship_queryable_fields[parent_type] = {}
parent_to_relationship_queryable_fields[parent_type][
child_relationship
] = queryable_fields
@@ -894,14 +1014,22 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
all_types.update(child_types)
# NOTE(rkuo): should this be an implicit parent type?
all_types.add("User") # Always add User for permissioning purposes
all_types.add("Account") # Always add Account for reference purposes
all_types.add(USER_OBJECT_TYPE) # Always add User for permissioning purposes
all_types.add(ACCOUNT_OBJECT_TYPE) # Always add Account for reference purposes
logger.info(f"All object types: num={len(all_types)} list={all_types}")
# Ensure User and Account have queryable fields if they weren't already processed
essential_types = [USER_OBJECT_TYPE, ACCOUNT_OBJECT_TYPE]
for essential_type in essential_types:
if essential_type not in type_to_queryable_fields:
type_to_queryable_fields[essential_type] = (
sf_client.get_queryable_fields_by_type(essential_type)
)
# 1.1 - Detect all fields in child types which reference a parent type.
# build dicts to detect relationships between parent and child
for child_type in child_types:
for child_type in child_types.union(essential_types):
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
parent_reference_fields = sf_client.get_parent_reference_fields(
child_type, parent_types
@@ -1003,6 +1131,32 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
yield doc_metadata_list
def validate_connector_settings(self) -> None:
"""
Validate that the Salesforce credentials and connector settings are correct.
Specifically checks that we can make an authenticated request to Salesforce.
"""
try:
# Call describe() to verify we can make an authenticated request to Salesforce
self.sf_client.describe()
except Exception as e:
raise ConnectorMissingCredentialError(
"Failed to validate Salesforce credentials. Please check your "
f"credentials and try again. Error: {e}"
)
if self.custom_query_config:
try:
_validate_custom_query_config(self.custom_query_config)
except Exception as e:
raise ConnectorMissingCredentialError(
"Failed to validate Salesforce custom query config. Please check your "
f"config and try again. Error: {e}"
)
logger.info("Salesforce credentials validated successfully.")
# @override
# def load_from_checkpoint(
# self,
@@ -1032,7 +1186,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
if __name__ == "__main__":
connector = SalesforceConnector(requested_objects=["Account"])
connector = SalesforceConnector(requested_objects=[ACCOUNT_OBJECT_TYPE])
connector.load_credentials(
{

View File

@@ -10,6 +10,8 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger
@@ -140,7 +142,7 @@ def _extract_primary_owner(
first_name=user_data.get("FirstName"),
last_name=user_data.get("LastName"),
email=user_data.get("Email"),
display_name=user_data.get("Name"),
display_name=user_data.get(NAME_FIELD),
)
# Check if all fields are None
@@ -166,8 +168,8 @@ def convert_sf_query_result_to_doc(
"""Generates a yieldable Document from query results"""
base_url = f"https://{sf_client.sf_instance}"
extracted_doc_updated_at = time_str_to_utc(record["LastModifiedDate"])
extracted_semantic_identifier = record.get("Name", "Unknown Object")
extracted_doc_updated_at = time_str_to_utc(record[MODIFIED_FIELD])
extracted_semantic_identifier = record.get(NAME_FIELD, "Unknown Object")
sections = [_extract_section(record, f"{base_url}/{record_id}")]
for child_record_key, child_record in child_records.items():
@@ -205,8 +207,8 @@ def convert_sf_object_to_doc(
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
base_url = f"https://{sf_instance}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
extracted_doc_updated_at = time_str_to_utc(object_dict[MODIFIED_FIELD])
extracted_semantic_identifier = object_dict.get(NAME_FIELD, "Unknown Object")
sections = [_extract_section(sf_object.data, f"{base_url}/{sf_object.id}")]
for id in sf_db.get_child_ids(sf_object.id):

View File

@@ -60,7 +60,7 @@ class OnyxSalesforce(Salesforce):
return True
for suffix in SALESFORCE_BLACKLISTED_SUFFIXES:
if object_type_lower.endswith(prefix):
if object_type_lower.endswith(suffix):
return True
return False
@@ -112,7 +112,7 @@ class OnyxSalesforce(Salesforce):
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> str:
"""Returns a SOQL query given the object id, type and child relationships.
@@ -148,7 +148,7 @@ class OnyxSalesforce(Salesforce):
self,
object_type: str,
object_id: str,
type_to_queryable_fields: dict[str, list[str]],
type_to_queryable_fields: dict[str, set[str]],
) -> dict[str, Any] | None:
record: dict[str, Any] = {}
@@ -172,7 +172,7 @@ class OnyxSalesforce(Salesforce):
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> dict[str, dict[str, Any]]:
"""There's a limit on the number of subqueries we can put in a single query."""
child_records: dict[str, dict[str, Any]] = {}
@@ -264,10 +264,10 @@ class OnyxSalesforce(Salesforce):
time.sleep(3)
raise
def get_queryable_fields_by_type(self, name: str) -> list[str]:
def get_queryable_fields_by_type(self, name: str) -> set[str]:
object_description = self.describe_type(name)
if object_description is None:
return []
return set()
fields: list[dict[str, Any]] = object_description["fields"]
valid_fields: set[str] = set()
@@ -286,7 +286,7 @@ class OnyxSalesforce(Salesforce):
if field_name:
valid_fields.add(field_name)
return list(valid_fields - field_names_to_remove)
return valid_fields - field_names_to_remove
def get_children_of_sf_type(self, sf_type: str) -> dict[str, str]:
"""Returns a dict of child object names to relationship names.

View File

@@ -14,6 +14,7 @@ from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -54,12 +55,12 @@ def _build_created_date_time_filter_for_salesforce(
def _make_time_filter_for_sf_type(
queryable_fields: list[str],
queryable_fields: set[str],
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> str | None:
if "LastModifiedDate" in queryable_fields:
if MODIFIED_FIELD in queryable_fields:
return _build_last_modified_time_filter_for_salesforce(start, end)
if "CreatedDate" in queryable_fields:
@@ -69,14 +70,14 @@ def _make_time_filter_for_sf_type(
def _make_time_filtered_query(
queryable_fields: list[str], sf_type: str, time_filter: str
queryable_fields: set[str], sf_type: str, time_filter: str
) -> str:
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
return query
def get_object_by_id_query(
object_id: str, sf_type: str, queryable_fields: list[str]
object_id: str, sf_type: str, queryable_fields: set[str]
) -> str:
query = (
f"SELECT {', '.join(queryable_fields)} FROM {sf_type} WHERE Id = '{object_id}'"
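Note that with `queryable_fields` now a `set`, the column order produced by `', '.join(...)` is no longer deterministic. A sketch of the same query construction (the `sorted()` call is our addition for stable output, not part of the connector):

```python
def get_object_by_id_query(
    object_id: str, sf_type: str, queryable_fields: set[str]
) -> str:
    # sorted() makes the SELECT list deterministic across runs
    fields = ", ".join(sorted(queryable_fields))
    return f"SELECT {fields} FROM {sf_type} WHERE Id = '{object_id}'"

print(get_object_by_id_query("001xx0000000001", "Account", {"Name", "Id"}))
# → SELECT Id, Name FROM Account WHERE Id = '001xx0000000001'
```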
@@ -193,7 +194,7 @@ def _bulk_retrieve_from_salesforce(
def fetch_all_csvs_in_parallel(
sf_client: Salesforce,
all_types_to_filter: dict[str, bool],
queryable_fields_by_type: dict[str, list[str]],
queryable_fields_by_type: dict[str, set[str]],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
target_dir: str,

View File

@@ -8,11 +8,15 @@ from pathlib import Path
from typing import Any
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list
logger = setup_logger()
@@ -567,7 +571,7 @@ class OnyxSalesforceSQLite:
uncommitted_rows = 0
# If we're updating User objects, update the email map
if object_type == "User":
if object_type == USER_OBJECT_TYPE:
OnyxSalesforceSQLite._update_user_email_map(cursor)
return updated_ids
@@ -619,7 +623,7 @@ class OnyxSalesforceSQLite:
with self._conn:
cursor = self._conn.cursor()
# Get the object data and account data
if object_type == "Account" or isChild:
if object_type == ACCOUNT_OBJECT_TYPE or isChild:
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
)
@@ -638,7 +642,7 @@ class OnyxSalesforceSQLite:
data = json.loads(result[0][0])
if object_type != "Account":
if object_type != ACCOUNT_OBJECT_TYPE:
# convert any account ids of the relationships back into data fields, with name
for row in result:
@@ -647,14 +651,14 @@ class OnyxSalesforceSQLite:
if len(row) < 3:
continue
if row[1] and row[2] and row[2] == "Account":
if row[1] and row[2] and row[2] == ACCOUNT_OBJECT_TYPE:
data["AccountId"] = row[1]
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?",
(row[1],),
)
account_data = json.loads(cursor.fetchone()[0])
data["Account"] = account_data.get("Name", "")
data[ACCOUNT_OBJECT_TYPE] = account_data.get(NAME_FIELD, "")
return SalesforceObject(id=object_id, type=object_type, data=data)

View File

@@ -2,6 +2,11 @@ import os
from dataclasses import dataclass
from typing import Any
NAME_FIELD = "Name"
MODIFIED_FIELD = "LastModifiedDate"
ACCOUNT_OBJECT_TYPE = "Account"
USER_OBJECT_TYPE = "User"
@dataclass
class SalesforceObject:

View File

@@ -1,13 +1,17 @@
import html
import io
import os
import re
import time
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from urllib.parse import unquote
import msal # type: ignore
import requests
from office365.graph_client import GraphClient # type: ignore
from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore
from office365.onedrive.sites.site import Site # type: ignore
@@ -33,6 +37,10 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
ASPX_EXTENSION = ".aspx"
REQUEST_TIMEOUT = 10
class SiteDescriptor(BaseModel):
"""Data class for storing SharePoint site information.
@@ -136,11 +144,156 @@ def _convert_driveitem_to_document(
return doc
def _convert_sitepage_to_document(
site_page: dict[str, Any], site_name: str | None
) -> Document:
"""Convert a SharePoint site page to a Document object."""
# Extract text content from the site page
page_text = ""
# Get title and description
title = cast(str, site_page.get("title", ""))
description = cast(str, site_page.get("description", ""))
# Build the text content
if title:
page_text += f"# {title}\n\n"
if description:
page_text += f"{description}\n\n"
# Extract content from canvas layout if available
canvas_layout = site_page.get("canvasLayout", {})
if canvas_layout:
horizontal_sections = canvas_layout.get("horizontalSections", [])
for section in horizontal_sections:
columns = section.get("columns", [])
for column in columns:
webparts = column.get("webparts", [])
for webpart in webparts:
# Extract text from different types of webparts
webpart_type = webpart.get("@odata.type", "")
# Extract text from text webparts
if webpart_type == "#microsoft.graph.textWebPart":
inner_html = webpart.get("innerHtml", "")
if inner_html:
# Basic HTML to text conversion
# Remove HTML tags but preserve some structure
text_content = re.sub(r"<br\s*/?>", "\n", inner_html)
text_content = re.sub(r"<li>", "", text_content)
text_content = re.sub(r"</li>", "\n", text_content)
text_content = re.sub(
r"<h[1-6][^>]*>", "\n## ", text_content
)
text_content = re.sub(r"</h[1-6]>", "\n", text_content)
text_content = re.sub(r"<p[^>]*>", "\n", text_content)
text_content = re.sub(r"</p>", "\n", text_content)
text_content = re.sub(r"<[^>]+>", "", text_content)
# Decode HTML entities
text_content = html.unescape(text_content)
# Clean up extra whitespace
text_content = re.sub(
r"\n\s*\n", "\n\n", text_content
).strip()
if text_content:
page_text += f"{text_content}\n\n"
# Extract text from standard webparts
elif webpart_type == "#microsoft.graph.standardWebPart":
data = webpart.get("data", {})
# Extract from serverProcessedContent
server_content = data.get("serverProcessedContent", {})
searchable_texts = server_content.get(
"searchablePlainTexts", []
)
for text_item in searchable_texts:
if isinstance(text_item, dict):
key = text_item.get("key", "")
value = text_item.get("value", "")
if value:
# Add context based on key
if key == "title":
page_text += f"## {value}\n\n"
else:
page_text += f"{value}\n\n"
# Extract description if available
description = data.get("description", "")
if description:
page_text += f"{description}\n\n"
# Extract title if available
webpart_title = data.get("title", "")
if webpart_title and webpart_title != description:
page_text += f"## {webpart_title}\n\n"
page_text = page_text.strip()
# If no content extracted, use the title as fallback
if not page_text and title:
page_text = title
# Parse creation and modification info
created_datetime = site_page.get("createdDateTime")
if created_datetime:
if isinstance(created_datetime, str):
created_datetime = datetime.fromisoformat(
created_datetime.replace("Z", "+00:00")
)
elif not created_datetime.tzinfo:
created_datetime = created_datetime.replace(tzinfo=timezone.utc)
last_modified_datetime = site_page.get("lastModifiedDateTime")
if last_modified_datetime:
if isinstance(last_modified_datetime, str):
last_modified_datetime = datetime.fromisoformat(
last_modified_datetime.replace("Z", "+00:00")
)
elif not last_modified_datetime.tzinfo:
last_modified_datetime = last_modified_datetime.replace(tzinfo=timezone.utc)
# Extract owner information
primary_owners = []
created_by = site_page.get("createdBy", {}).get("user", {})
if created_by.get("displayName"):
primary_owners.append(
BasicExpertInfo(
display_name=created_by.get("displayName"),
email=created_by.get("email", ""),
)
)
web_url = site_page["webUrl"]
semantic_identifier = cast(str, site_page.get("name", title))
if semantic_identifier.endswith(ASPX_EXTENSION):
semantic_identifier = semantic_identifier[: -len(ASPX_EXTENSION)]
doc = Document(
id=site_page["id"],
sections=[TextSection(link=web_url, text=page_text)],
source=DocumentSource.SHAREPOINT,
semantic_identifier=semantic_identifier,
doc_updated_at=last_modified_datetime or created_datetime,
primary_owners=primary_owners,
metadata=(
{
"site": site_name,
}
if site_name
else {}
),
)
return doc
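The textWebPart handling above is a regex-based HTML-to-text pass: line breaks and paragraphs become newlines, headings become `## ` markers, remaining tags are stripped, and entities are decoded. A self-contained sketch of that pipeline (it collapses the separate `<p>`/`</p>` substitutions into one pattern; like the original, it is a rough conversion, not a full HTML parser):

```python
import html
import re

def html_to_text(inner_html: str) -> str:
    # Roughly mirrors the textWebPart conversion: structure-preserving
    # substitutions first, then tag stripping and entity decoding.
    text = re.sub(r"<br\s*/?>", "\n", inner_html)
    text = re.sub(r"</li>", "\n", re.sub(r"<li>", "", text))
    text = re.sub(r"<h[1-6][^>]*>", "\n## ", text)
    text = re.sub(r"</h[1-6]>", "\n", text)
    text = re.sub(r"</?p[^>]*>", "\n", text)
    text = re.sub(r"<[^>]+>", "", text)  # drop any remaining tags
    text = html.unescape(text)
    return re.sub(r"\n\s*\n", "\n\n", text).strip()

print(html_to_text("<p>Hello &amp; welcome</p><h2>Docs</h2>"))
```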
class SharepointConnector(LoadConnector, PollConnector):
def __init__(
self,
batch_size: int = INDEX_BATCH_SIZE,
sites: list[str] = [],
include_site_pages: bool = True,
) -> None:
self.batch_size = batch_size
self._graph_client: GraphClient | None = None
@@ -148,6 +301,7 @@ class SharepointConnector(LoadConnector, PollConnector):
sites
)
self.msal_app: msal.ConfidentialClientApplication | None = None
self.include_site_pages = include_site_pages
@property
def graph_client(self) -> GraphClient:
@@ -284,7 +438,7 @@ class SharepointConnector(LoadConnector, PollConnector):
except Exception as e:
# Some drives might not be accessible
logger.warning(f"Failed to process drive: {str(e)}")
logger.warning(f"Failed to process drive '{drive.name}': {str(e)}")
except Exception as e:
err_str = str(e)
@@ -327,6 +481,74 @@ class SharepointConnector(LoadConnector, PollConnector):
]
return site_descriptors
def _fetch_site_pages(
self,
site_descriptor: SiteDescriptor,
start: datetime | None = None,
end: datetime | None = None,
) -> list[dict[str, Any]]:
"""Fetch SharePoint site pages (.aspx files) using the SharePoint Pages API."""
# Get the site to extract the site ID
site = self.graph_client.sites.get_by_url(site_descriptor.url)
site.execute_query() # Execute the query to actually fetch the data
site_id = site.id
# Get the token acquisition function from the GraphClient
token_data = self._acquire_token()
access_token = token_data.get("access_token")
if not access_token:
raise RuntimeError("Failed to acquire access token")
# Construct the SharePoint Pages API endpoint
# Call the REST API directly, since the Graph client library doesn't support the Pages API
pages_endpoint = f"https://graph.microsoft.com/v1.0/sites/{site_id}/pages/microsoft.graph.sitePage"
headers = {
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
# Add expand parameter to get canvas layout content
params = {"$expand": "canvasLayout"}
response = requests.get(
pages_endpoint, headers=headers, params=params, timeout=REQUEST_TIMEOUT
)
response.raise_for_status()
pages_data = response.json()
all_pages = pages_data.get("value", [])
# Handle pagination if there are more pages
while "@odata.nextLink" in pages_data:
next_url = pages_data["@odata.nextLink"]
response = requests.get(next_url, headers=headers, timeout=REQUEST_TIMEOUT)
response.raise_for_status()
pages_data = response.json()
all_pages.extend(pages_data.get("value", []))
logger.debug(f"Found {len(all_pages)} site pages in {site_descriptor.url}")
# Filter pages based on time window if specified
if start is not None or end is not None:
filtered_pages = []
for page in all_pages:
page_modified = page.get("lastModifiedDateTime")
if page_modified:
if isinstance(page_modified, str):
page_modified = datetime.fromisoformat(
page_modified.replace("Z", "+00:00")
)
if start is not None and page_modified < start:
continue
if end is not None and page_modified > end:
continue
filtered_pages.append(page)
all_pages = filtered_pages
return all_pages
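The pagination loop above follows Microsoft Graph's `@odata.nextLink` convention: each response may carry a URL for the next page, and the loop keeps requesting until the link is absent. A network-free sketch of the pattern, with `get_json` standing in for the authenticated `requests.get(...).json()` call:

```python
from typing import Callable

def fetch_all_pages(first_url: str, get_json: Callable[[str], dict]) -> list:
    # Follow Graph-style "@odata.nextLink" pagination until it disappears.
    data = get_json(first_url)
    items = list(data.get("value", []))
    while "@odata.nextLink" in data:
        data = get_json(data["@odata.nextLink"])
        items.extend(data.get("value", []))
    return items

# Simulated two-page response served from a dict instead of HTTP.
pages = {
    "p1": {"value": [1, 2], "@odata.nextLink": "p2"},
    "p2": {"value": [3]},
}
print(fetch_all_pages("p1", pages.__getitem__))  # → [1, 2, 3]
```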
def _fetch_from_sharepoint(
self, start: datetime | None = None, end: datetime | None = None
) -> GenerateDocumentsOutput:
@@ -335,6 +557,7 @@ class SharepointConnector(LoadConnector, PollConnector):
# goes over all urls, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
for site_descriptor in site_descriptors:
# Fetch regular documents from document libraries
driveitems = self._fetch_driveitems(site_descriptor, start=start, end=end)
for driveitem, drive_name in driveitems:
logger.debug(f"Processing: {driveitem.web_url}")
@@ -347,8 +570,47 @@ class SharepointConnector(LoadConnector, PollConnector):
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
# Fetch SharePoint site pages (.aspx files)
# Only fetch site pages if no folder is specified, since this processing
# happens at a site-wide level and specifying a folder implies that the
# user probably isn't looking for site pages
specified_path = (
site_descriptor.folder_path is not None
or site_descriptor.drive_name is not None
)
if self.include_site_pages and not specified_path:
site_pages = self._fetch_site_pages(
site_descriptor, start=start, end=end
)
for site_page in site_pages:
logger.debug(
f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}"
)
doc_batch.append(
_convert_sitepage_to_document(
site_page, site_descriptor.drive_name
)
)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
yield doc_batch
def _acquire_token(self) -> dict[str, Any]:
"""
Acquire token via MSAL
"""
if self.msal_app is None:
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
)
return token
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
sp_client_id = credentials["sp_client_id"]
sp_client_secret = credentials["sp_client_secret"]
@@ -360,20 +622,7 @@ class SharepointConnector(LoadConnector, PollConnector):
client_id=sp_client_id,
client_credential=sp_client_secret,
)
def _acquire_token_func() -> dict[str, Any]:
"""
Acquire token via MSAL
"""
if self.msal_app is None:
raise RuntimeError("MSAL app is not initialized")
token = self.msal_app.acquire_token_for_client(
scopes=["https://graph.microsoft.com/.default"]
)
return token
self._graph_client = GraphClient(_acquire_token_func)
self._graph_client = GraphClient(self._acquire_token)
return None
def load_from_state(self) -> GenerateDocumentsOutput:

View File

@@ -48,7 +48,9 @@ from onyx.db.relationships import (
delete_from_kg_relationships_extraction_staging__no_commit,
)
from onyx.db.tag import delete_document_tags_for_documents__no_commit
from onyx.db.utils import DocumentRow
from onyx.db.utils import model_to_dict
from onyx.db.utils import SortOrder
from onyx.document_index.interfaces import DocumentMetadata
from onyx.kg.models import KGStage
from onyx.server.documents.models import ConnectorCredentialPairIdentifier
@@ -150,7 +152,7 @@ def get_documents_for_cc_pair(
def get_document_ids_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
db_session: Session, connector_id: int, credential_id: int
) -> list[str]:
doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
@@ -161,6 +163,47 @@ def get_document_ids_for_connector_credential_pair(
return list(db_session.execute(doc_ids_stmt).scalars().all())
def get_documents_for_connector_credential_pair_limited_columns(
db_session: Session,
connector_id: int,
credential_id: int,
sort_order: SortOrder | None = None,
) -> Sequence[DocumentRow]:
doc_ids_subquery = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
)
doc_ids_subquery = doc_ids_subquery.join(
DbDocument, DocumentByConnectorCredentialPair.id == DbDocument.id
)
stmt = select(
DbDocument.id, DbDocument.doc_metadata, DbDocument.external_user_group_ids
)
stmt = stmt.where(DbDocument.id.in_(doc_ids_subquery))
if sort_order == SortOrder.ASC:
stmt = stmt.order_by(DbDocument.last_modified.asc())
elif sort_order == SortOrder.DESC:
stmt = stmt.order_by(DbDocument.last_modified.desc())
rows = db_session.execute(stmt).mappings().all()
doc_rows: list[DocumentRow] = []
for row in rows:
doc_row = DocumentRow(
id=row.id,
doc_metadata=row.doc_metadata,
external_user_group_ids=row.external_user_group_ids,
)
doc_rows.append(doc_row)
return doc_rows
def get_documents_for_connector_credential_pair(
db_session: Session, connector_id: int, credential_id: int, limit: int | None = None
) -> Sequence[DbDocument]:
@@ -370,6 +413,7 @@ def upsert_documents(
if doc.external_access
else {}
),
doc_metadata=doc.doc_metadata,
)
)
for doc in seen_documents.values()
@@ -389,6 +433,7 @@ def upsert_documents(
"external_user_emails": insert_stmt.excluded.external_user_emails,
"external_user_group_ids": insert_stmt.excluded.external_user_group_ids,
"is_public": insert_stmt.excluded.is_public,
"doc_metadata": insert_stmt.excluded.doc_metadata,
},
)
db_session.execute(on_conflict_stmt)
@@ -1031,7 +1076,7 @@ def reset_all_document_kg_stages(db_session: Session) -> int:
# The hasattr check is needed for type checking, even though rowcount
# is guaranteed to exist at runtime for UPDATE operations
return result.rowcount if hasattr(result, "rowcount") else 0
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
def update_document_kg_stages(
@@ -1054,7 +1099,7 @@ def update_document_kg_stages(
result = db_session.execute(stmt)
# The hasattr check is needed for type checking, even though rowcount
# is guaranteed to exist at runtime for UPDATE operations
return result.rowcount if hasattr(result, "rowcount") else 0
return result.rowcount if hasattr(result, "rowcount") else 0 # type: ignore
def get_skipped_kg_documents(db_session: Session) -> list[str]:

View File

@@ -12,6 +12,7 @@ from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Session
import onyx.db.document as dbdocument
from onyx.db.entity_type import UNGROUNDED_SOURCE_NAME
from onyx.db.models import Document
from onyx.db.models import KGEntity
from onyx.db.models import KGEntityExtractionStaging
@@ -328,7 +329,13 @@ def get_entity_stats_by_grounded_source_name(
.group_by(KGEntityType.grounded_source_name)
.all()
)
# `row.grounded_source_name` is NULLABLE in the database schema.
# Thus, for all "ungrounded" entity-types, we use a default name.
return {
row.grounded_source_name: (row.last_updated, row.entities_count)
(row.grounded_source_name or UNGROUNDED_SOURCE_NAME): (
row.last_updated,
row.entities_count,
)
for row in results
}
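The nullable-key fallback above can be sketched in isolation. This is a minimal stand-in, not the real query: `Row` and the sample data are hypothetical substitutes for the SQLAlchemy result rows.

```python
from collections import namedtuple

UNGROUNDED_SOURCE_NAME = "Ungrounded"
# Stand-in for the ORM row returned by the grouped KGEntityType query
Row = namedtuple("Row", ["grounded_source_name", "last_updated", "entities_count"])

def stats_by_source(rows):
    # `or` collapses both None and "" onto the default "Ungrounded" bucket,
    # so ungrounded entity types no longer produce a None dict key
    return {
        (r.grounded_source_name or UNGROUNDED_SOURCE_NAME): (
            r.last_updated,
            r.entities_count,
        )
        for r in rows
    }

rows = [Row("salesforce", "2025-08-07", 10), Row(None, "2025-08-06", 3)]
```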

View File

@@ -11,7 +11,7 @@ from onyx.kg.models import KGAttributeEntityOption
from onyx.server.kg.models import EntityType
_UNGROUNDED_SOURCE_NAME = "Ungrounded"
UNGROUNDED_SOURCE_NAME = "Ungrounded"
def get_entity_types_with_grounded_source_name(
@@ -87,7 +87,7 @@ def get_configured_entity_types(db_session: Session) -> dict[str, list[KGEntityT
et_map = defaultdict(list)
for et in ets:
key = et.grounded_source_name or _UNGROUNDED_SOURCE_NAME
key = et.grounded_source_name or UNGROUNDED_SOURCE_NAME
et_map[key].append(et)
return et_map

View File

@@ -267,6 +267,7 @@ class IndexingCoordination:
index_attempt_id: int,
current_batches_completed: int,
timeout_hours: int = INDEXING_PROGRESS_TIMEOUT_HOURS,
force_update_progress: bool = False,
) -> bool:
"""
Update progress tracking for stall detection.
@@ -281,7 +282,8 @@ class IndexingCoordination:
current_time = get_db_current_time(db_session)
# No progress - check if this is the first time tracking
if attempt.last_progress_time is None:
# or if the caller wants to simulate guaranteed progress
if attempt.last_progress_time is None or force_update_progress:
# First time tracking - initialize
attempt.last_progress_time = current_time
attempt.last_batches_completed_count = current_batches_completed

View File

@@ -608,6 +608,10 @@ class Document(Base):
retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
)
doc_metadata: Mapped[dict[str, Any] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
)
tags = relationship(
"Tag",
secondary=Document__Tag.__table__,

View File

@@ -15,6 +15,7 @@ from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.chat_configs import BING_API_KEY
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
@@ -96,6 +97,14 @@ def _add_user_filters(
where_clause = Persona.is_public == True # noqa: E712
return stmt.where(where_clause)
# If curator ownership restriction is enabled, curators can only access their own assistants
if CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS and user.role in [
UserRole.CURATOR,
UserRole.GLOBAL_CURATOR,
]:
where_clause = (Persona.user_id == user.id) | (Persona.user_id.is_(None))
return stmt.where(where_clause)
where_clause = User__UserGroup.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UserGroup.is_curator == True # noqa: E712

View File

@@ -115,6 +115,7 @@ def create_file_connector_credential(
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"file_names": [user_file.name],
"zip_metadata": {},
},
refresh_freq=None,

View File

@@ -1,7 +1,9 @@
from enum import Enum
from typing import Any
from psycopg2 import errorcodes
from psycopg2 import OperationalError
from pydantic import BaseModel
from sqlalchemy import inspect
from onyx.db.models import Base
@@ -27,3 +29,14 @@ def is_retryable_sqlalchemy_error(exc: BaseException) -> bool:
pgcode = getattr(getattr(exc, "orig", None), "pgcode", None)
return pgcode in RETRYABLE_PG_CODES
return False
class DocumentRow(BaseModel):
id: str
doc_metadata: dict[str, Any]
external_user_group_ids: list[str]
class SortOrder(str, Enum):
ASC = "asc"
DESC = "desc"

View File

@@ -91,6 +91,7 @@ class DocumentMetadata:
from_ingestion_api: bool = False
external_access: ExternalAccess | None = None
doc_metadata: dict[str, Any] | None = None
@dataclass

View File

@@ -17,11 +17,11 @@ from typing import NamedTuple
from zipfile import BadZipFile
import chardet
import docx # type: ignore
import openpyxl # type: ignore
import pptx # type: ignore
from docx import Document as DocxDocument
from fastapi import UploadFile
from markitdown import FileConversionException
from markitdown import MarkItDown
from markitdown import UnsupportedFormatException
from PIL import Image
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
@@ -83,11 +83,6 @@ IMAGE_MEDIA_TYPES = [
"image/webp",
]
KNOWN_OPENPYXL_BUGS = [
"Value must be either numerical or a string containing a wildcard",
"File contains no valid workbook part",
]
class OnyxExtensionType(IntFlag):
Plain = auto()
@@ -149,6 +144,13 @@ def is_macos_resource_fork_file(file_name: str) -> bool:
)
def to_bytesio(stream: IO[bytes]) -> BytesIO:
if isinstance(stream, BytesIO):
return stream
data = stream.read() # consumes the stream!
return BytesIO(data)
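The `to_bytesio` helper matters because `stream.read()` consumes a one-shot stream; buffering the bytes into a `BytesIO` makes them seekable and re-readable. A self-contained illustration, with a minimal non-seekable stream stand-in:

```python
from io import BytesIO
from typing import IO

def to_bytesio(stream: IO[bytes]) -> BytesIO:
    if isinstance(stream, BytesIO):
        return stream
    return BytesIO(stream.read())  # read once, then rewindable forever

class _OneShotStream:
    """Minimal stand-in for a non-seekable stream (e.g. a network body)."""
    def __init__(self, data: bytes) -> None:
        self._buf = BytesIO(data)
    def read(self, n: int = -1) -> bytes:
        return self._buf.read(n)

buf = to_bytesio(_OneShotStream(b"hello"))  # buffered into memory
```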
def load_files_from_zip(
zip_file_io: IO,
ignore_macos_resource_fork_files: bool = True,
@@ -305,19 +307,38 @@ def read_pdf_file(
return "", metadata, []
def extract_docx_images(docx_bytes: IO[Any]) -> list[tuple[bytes, str]]:
"""
Given the bytes of a docx file, extract all the images.
Returns a list of tuples (image_bytes, image_name).
"""
out = []
try:
with zipfile.ZipFile(docx_bytes) as z:
for name in z.namelist():
if name.startswith("word/media/"):
out.append((z.read(name), name.split("/")[-1]))
except Exception:
logger.exception("Failed to extract all docx images")
return out
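Because a .docx file is just a zip archive, the image extraction above only needs to read the `word/media/` entries. The core of that logic, stripped of the error handling:

```python
import io
import zipfile

def extract_zip_media(data: bytes, prefix: str = "word/media/") -> list[tuple[bytes, str]]:
    """Return (bytes, basename) for every archive entry under `prefix`."""
    out = []
    with zipfile.ZipFile(io.BytesIO(data)) as z:
        for name in z.namelist():
            if name.startswith(prefix):
                out.append((z.read(name), name.split("/")[-1]))
    return out
```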
def docx_to_text_and_images(
file: IO[Any], file_name: str = ""
) -> tuple[str, Sequence[tuple[bytes, str]]]:
"""
Extract text from a docx. If embed_images=True, also extract inline images.
Extract text from a docx.
Return (text_content, list_of_images).
"""
paragraphs = []
embedded_images: list[tuple[bytes, str]] = []
md = MarkItDown(enable_plugins=False)
try:
doc = docx.Document(file)
except BadZipFile as e:
doc = md.convert(to_bytesio(file))
except (
BadZipFile,
ValueError,
FileConversionException,
UnsupportedFormatException,
) as e:
logger.warning(
f"Failed to extract docx {file_name or 'docx file'}: {e}. Attempting to read as text file."
)
@@ -330,86 +351,44 @@ def docx_to_text_and_images(
)
return text_content_raw or "", []
# Grab text from paragraphs
for paragraph in doc.paragraphs:
paragraphs.append(paragraph.text)
# Reset position so we can re-load the doc (python-docx has read the stream)
# Note: if python-docx has fully consumed the stream, you may need to open it again from memory.
# For large docs, a more robust approach is needed.
# This is a simplified example.
for rel_id, rel in doc.part.rels.items():
if "image" in rel.reltype:
# image is typically in rel.target_part.blob
image_bytes = rel.target_part.blob
image_name = rel.target_part.partname
# store
embedded_images.append((image_bytes, os.path.basename(str(image_name))))
text_content = "\n".join(paragraphs)
return text_content, embedded_images
file.seek(0)
return doc.markdown, extract_docx_images(to_bytesio(file))
def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
md = MarkItDown(enable_plugins=False)
try:
presentation = pptx.Presentation(file)
except BadZipFile as e:
presentation = md.convert(to_bytesio(file))
except (
BadZipFile,
ValueError,
FileConversionException,
UnsupportedFormatException,
) as e:
error_str = f"Failed to extract text from {file_name or 'pptx file'}: {e}"
logger.warning(error_str)
return ""
text_content = []
for slide_number, slide in enumerate(presentation.slides, start=1):
slide_text = f"\nSlide {slide_number}:\n"
for shape in slide.shapes:
if hasattr(shape, "text"):
slide_text += shape.text + "\n"
text_content.append(slide_text)
return TEXT_SECTION_SEPARATOR.join(text_content)
return presentation.markdown
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
md = MarkItDown(enable_plugins=False)
try:
workbook = openpyxl.load_workbook(file, read_only=True)
except BadZipFile as e:
workbook = md.convert(to_bytesio(file))
except (
BadZipFile,
ValueError,
FileConversionException,
UnsupportedFormatException,
) as e:
error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
if file_name.startswith("~"):
logger.debug(error_str + " (this is expected for files with ~)")
else:
logger.warning(error_str)
return ""
except Exception as e:
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
logger.error(
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
)
return ""
raise e
text_content = []
for sheet in workbook.worksheets:
rows = []
num_empty_consecutive_rows = 0
for row in sheet.iter_rows(min_row=1, values_only=True):
row_str = ",".join(str(cell or "") for cell in row)
# Only add the row if there are any values in the cells
if len(row_str) >= len(row):
rows.append(row_str)
num_empty_consecutive_rows = 0
else:
num_empty_consecutive_rows += 1
if num_empty_consecutive_rows > 100:
# handle massive excel sheets with mostly empty cells
logger.warning(
f"Found {num_empty_consecutive_rows} empty rows in {file_name},"
" skipping rest of file"
)
break
sheet_str = "\n".join(rows)
text_content.append(sheet_str)
return TEXT_SECTION_SEPARATOR.join(text_content)
return workbook.markdown
def eml_to_text(file: IO[Any]) -> str:
@@ -462,9 +441,9 @@ def extract_file_text(
"""
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
".docx": lambda f: docx_to_text_and_images(f)[0], # no images
".pptx": pptx_to_text,
".xlsx": xlsx_to_text,
".docx": lambda f: docx_to_text_and_images(f, file_name)[0], # no images
".pptx": lambda f: pptx_to_text(f, file_name),
".xlsx": lambda f: xlsx_to_text(f, file_name),
".eml": eml_to_text,
".epub": epub_to_text,
".html": parse_html_page_basic,
@@ -522,23 +501,28 @@ def extract_text_and_images(
Primary new function for the updated connector.
Returns structured extraction result with text content, embedded images, and metadata.
"""
file.seek(0)
try:
# Attempt unstructured if env var is set
if get_unstructured_api_key():
# If the user doesn't want embedded images, unstructured is fine
file.seek(0)
if get_unstructured_api_key():
try:
text_content = unstructured_to_text(file, file_name)
return ExtractionResult(
text_content=text_content, embedded_images=[], metadata={}
)
except Exception as e:
logger.error(
f"Failed to process with Unstructured: {str(e)}. "
"Falling back to normal processing."
)
file.seek(0) # Reset file pointer just in case
# Default processing
try:
extension = get_file_ext(file_name)
# docx example for embedded images
if extension == ".docx":
file.seek(0)
text_content, images = docx_to_text_and_images(file)
text_content, images = docx_to_text_and_images(file, file_name)
return ExtractionResult(
text_content=text_content, embedded_images=images, metadata={}
)
@@ -546,7 +530,6 @@ def extract_text_and_images(
# PDF example: we do not show complicated PDF image extraction here
# so we simply extract text for now and skip images.
if extension == ".pdf":
file.seek(0)
text_content, pdf_metadata, images = read_pdf_file(
file,
pdf_pass,
@@ -559,7 +542,6 @@ def extract_text_and_images(
# For PPTX, XLSX, EML, etc., we do not show embedded image logic here.
# You can do something similar to docx if needed.
if extension == ".pptx":
file.seek(0)
return ExtractionResult(
text_content=pptx_to_text(file, file_name=file_name),
embedded_images=[],
@@ -567,7 +549,6 @@ def extract_text_and_images(
)
if extension == ".xlsx":
file.seek(0)
return ExtractionResult(
text_content=xlsx_to_text(file, file_name=file_name),
embedded_images=[],
@@ -575,19 +556,16 @@ def extract_text_and_images(
)
if extension == ".eml":
file.seek(0)
return ExtractionResult(
text_content=eml_to_text(file), embedded_images=[], metadata={}
)
if extension == ".epub":
file.seek(0)
return ExtractionResult(
text_content=epub_to_text(file), embedded_images=[], metadata={}
)
if extension == ".html":
file.seek(0)
return ExtractionResult(
text_content=parse_html_page_basic(file),
embedded_images=[],
@@ -596,7 +574,6 @@ def extract_text_and_images(
# If we reach here and it's a recognized text extension
if is_text_file_extension(file_name):
file.seek(0)
encoding = detect_encoding(file)
text_content_raw, file_metadata = read_text_file(
file, encoding=encoding, ignore_onyx_metadata=False
@@ -609,7 +586,6 @@ def extract_text_and_images(
# If it's an image file or something else, we do not parse embedded images from them
# just return empty text
file.seek(0)
return ExtractionResult(text_content="", embedded_images=[], metadata={})
except Exception as e:

View File

@@ -32,9 +32,11 @@ def is_valid_image_type(mime_type: str) -> bool:
Returns:
True if the MIME type is a valid image type, False otherwise
"""
if not mime_type:
return False
return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
return (
bool(mime_type)
and mime_type.startswith("image/")
and mime_type not in EXCLUDED_IMAGE_TYPES
)
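The refactor above folds the empty-string guard into a single boolean expression with no behavior change. A runnable sketch (the contents of `EXCLUDED_IMAGE_TYPES` are assumed here for illustration):

```python
EXCLUDED_IMAGE_TYPES = {"image/svg+xml"}  # assumed contents, not the real list

def is_valid_image_type(mime_type: str) -> bool:
    # bool(mime_type) rejects both "" and None-ish falsy inputs,
    # replacing the earlier explicit early-return
    return (
        bool(mime_type)
        and mime_type.startswith("image/")
        and mime_type not in EXCLUDED_IMAGE_TYPES
    )
```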
def is_supported_by_vision_llm(mime_type: str) -> bool:

View File

@@ -1,6 +1,7 @@
import re
from copy import copy
from dataclasses import dataclass
from io import BytesIO
from typing import IO
import bs4
@@ -161,7 +162,7 @@ def format_document_soup(
return strip_excessive_newlines_and_spaces(text)
def parse_html_page_basic(text: str | IO[bytes]) -> str:
def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
soup = bs4.BeautifulSoup(text, "html.parser")
return format_document_soup(soup)

View File

@@ -196,6 +196,9 @@ class FileStoreDocumentBatchStorage(DocumentBatchStorage):
for batch_file_name in batch_names:
path_info = self.extract_path_info(batch_file_name)
if path_info is None:
logger.warning(
f"Could not extract path info from batch file: {batch_file_name}"
)
continue
new_batch_file_name = self._get_batch_file_name(path_info.batch_num)
self.file_store.change_file_id(batch_file_name, new_batch_file_name)

View File

@@ -49,11 +49,10 @@ def sanitize_s3_key_name(file_name: str) -> str:
# Characters to avoid completely (replace with underscore)
# These are characters that AWS recommends avoiding
avoid_chars = r'[\\{}^%`\[\]"<>#|~]'
avoid_chars = r'[\\{}^%`\[\]"<>#|~/]'
# Replace avoided characters with underscore
sanitized = re.sub(avoid_chars, "_", file_name)
# Characters that might require special handling but are allowed
# We'll URL encode these to be safe
special_chars = r"[&$@=;:+,?\s]"
@@ -81,6 +80,9 @@ def sanitize_s3_key_name(file_name: str) -> str:
# Remove any trailing periods to avoid download issues
sanitized = sanitized.rstrip(".")
# Remove multiple separators
sanitized = re.sub(r"[-_]{2,}", "-", sanitized)
# If sanitization resulted in empty string, use a default
if not sanitized:
sanitized = "sanitized_file"
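The S3 key change adds `/` to the avoided-character class and collapses runs of separators. A reduced sketch of just those two steps (the real function also URL-encodes a second character class and trims trailing periods):

```python
import re

def sanitize_fragment(name: str) -> str:
    # Replace AWS-discouraged characters, now including "/", with "_"
    name = re.sub(r'[\\{}^%`\[\]"<>#|~/]', "_", name)
    # Collapse runs of two or more "-" or "_" into a single "-"
    name = re.sub(r"[-_]{2,}", "-", name)
    # Fall back to a default if nothing survived sanitization
    return name or "sanitized_file"
```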

View File

@@ -46,7 +46,6 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
# Use a separate session to avoid committing the caller's transaction
try:
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))

View File

@@ -142,6 +142,7 @@ def _upsert_documents_in_db(
secondary_owners=get_experts_stores_representations(doc.secondary_owners),
from_ingestion_api=doc.from_ingestion_api,
external_access=doc.external_access,
doc_metadata=doc.doc_metadata,
)
document_metadata_list.append(db_doc_metadata)
@@ -866,31 +867,27 @@ def index_doc_batch(
user_file_id_to_raw_text: dict[int, str] = {}
for document_id in updatable_ids:
# Only calculate token counts for documents that have a user file ID
if (
document_id in doc_id_to_user_file_id
and doc_id_to_user_file_id[document_id] is not None
):
user_file_id = doc_id_to_user_file_id[document_id]
if not user_file_id:
continue
document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content))
if llm_tokenizer
else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None
user_file_id = doc_id_to_user_file_id.get(document_id)
if user_file_id is None:
continue
document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.

View File

@@ -313,14 +313,14 @@ class DefaultMultiLLM(LLM):
self._model_kwargs = model_kwargs
def log_model_configs(self) -> None:
logger.debug(f"Config: {self.config}")
def _safe_model_config(self) -> dict:
dump = self.config.model_dump()
dump["api_key"] = mask_string(dump.get("api_key", ""))
return dump
def log_model_configs(self) -> None:
logger.debug(f"Config: {self._safe_model_config()}")
def _record_call(self, prompt: LanguageModelInput) -> None:
if self._long_term_logger:
self._long_term_logger.record(
@@ -397,7 +397,11 @@ class DefaultMultiLLM(LLM):
# streaming choice
stream=stream,
# model params
temperature=self._temperature,
temperature=(
1
if self.config.model_name in ["gpt-5", "gpt-5-mini", "gpt-5-nano"]
else self._temperature
),
timeout=timeout_override or self._timeout,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
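The temperature override above reflects that the gpt-5 family rejects non-default temperature values, so requests for those models are pinned to 1. The selection logic in isolation (model list taken from the diff; the constraint itself is an assumption about current OpenAI behavior):

```python
GPT5_MODELS = {"gpt-5", "gpt-5-mini", "gpt-5-nano"}

def effective_temperature(model_name: str, requested: float) -> float:
    # gpt-5 family models currently accept only the default temperature of 1
    return 1 if model_name in GPT5_MODELS else requested
```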

View File

@@ -47,6 +47,9 @@ class WellKnownLLMProviderDescriptor(BaseModel):
OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
"o4-mini",
"o3-mini",
"o1-mini",
@@ -73,7 +76,14 @@ OPEN_AI_MODEL_NAMES = [
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
]
OPEN_AI_VISIBLE_MODEL_NAMES = ["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"]
OPEN_AI_VISIBLE_MODEL_NAMES = [
"gpt-5",
"gpt-5-mini",
"o1",
"o3-mini",
"gpt-4o",
"gpt-4o-mini",
]
BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named

View File

@@ -151,7 +151,7 @@ def _build_ephemeral_publication_block(
email=message_info.email,
sender_id=message_info.sender_id,
thread_messages=[],
is_bot_msg=message_info.is_bot_msg,
is_slash_command=message_info.is_slash_command,
is_bot_dm=message_info.is_bot_dm,
thread_to_respond=respond_ts,
)
@@ -225,10 +225,10 @@ def _build_doc_feedback_block(
def get_restate_blocks(
msg: str,
is_bot_msg: bool,
is_slash_command: bool,
) -> list[Block]:
# Only the slash command needs this context because the user doesn't see their own input
if not is_bot_msg:
if not is_slash_command:
return []
return [
@@ -576,7 +576,7 @@ def build_slack_response_blocks(
# If called with the OnyxBot slash command, the question is lost so we have to reshow it
if not skip_restated_question:
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
message_info.thread_messages[-1].message, message_info.is_slash_command
)
else:
restate_question_block = []

View File

@@ -177,7 +177,7 @@ def handle_generate_answer_button(
sender_id=user_id or None,
email=email or None,
bypass_filters=True,
is_bot_msg=False,
is_slash_command=False,
is_bot_dm=False,
),
slack_channel_config=slack_channel_config,

View File

@@ -28,7 +28,7 @@ logger_base = setup_logger()
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
if details.is_bot_msg and details.sender_id:
if details.is_slash_command and details.sender_id:
respond_in_thread_or_channel(
client=client,
channel=details.channel_to_respond,
@@ -124,11 +124,11 @@ def handle_message(
messages = message_info.thread_messages
sender_id = message_info.sender_id
bypass_filters = message_info.bypass_filters
is_bot_msg = message_info.is_bot_msg
is_slash_command = message_info.is_slash_command
is_bot_dm = message_info.is_bot_dm
action = "slack_message"
if is_bot_msg:
if is_slash_command:
action = "slack_slash_message"
elif bypass_filters:
action = "slack_tag_message"
@@ -197,7 +197,7 @@ def handle_message(
# If configured to respond to team members only, then cannot be used with a /OnyxBot command
# which would just respond to the sender
if send_to and is_bot_msg:
if send_to and is_slash_command:
if sender_id:
respond_in_thread_or_channel(
client=client,

View File

@@ -81,15 +81,15 @@ def handle_regular_answer(
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
is_slash_command = message_info.is_slash_command
# Capture whether response mode for channel is ephemeral. Even if the channel is set
# to respond with an ephemeral message, we still send as non-ephemeral if
# the message is a dm with the Onyx bot.
send_as_ephemeral = (
slack_channel_config.channel_config.get("is_ephemeral", False)
and not message_info.is_bot_dm
)
or message_info.is_slash_command
) and not message_info.is_bot_dm
# If the channel is configured to respond with an ephemeral message,
# or the message is a dm to the Onyx bot, we should use the proper onyx user from the email.
@@ -164,7 +164,7 @@ def handle_regular_answer(
# in an attached document set were available to all users in the channel.)
bypass_acl = False
if not message_ts_to_respond_to and not is_bot_msg:
if not message_ts_to_respond_to and not is_slash_command:
# if the message is not a "/onyx" command, then it should have a message ts to respond to
raise RuntimeError(
"No message timestamp to respond to in `handle_message`. This should never happen."
@@ -316,13 +316,14 @@ def handle_regular_answer(
return True
# Got an answer at this point, can remove reaction and give results
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
if not is_slash_command: # Slash commands don't have reactions
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
if answer.answer_valid is False:
logger.notice(

View File

@@ -876,12 +876,13 @@ def build_request_details(
sender_id=sender_id,
email=email,
bypass_filters=tagged,
is_bot_msg=False,
is_slash_command=False,
is_bot_dm=event.get("channel_type") == "im",
)
elif req.type == "slash_commands":
channel = req.payload["channel_id"]
channel_name = req.payload["channel_name"]
msg = req.payload["text"]
sender = req.payload["user_id"]
expert_info = expert_info_from_slack_id(
@@ -899,8 +900,8 @@ def build_request_details(
sender_id=sender,
email=email,
bypass_filters=True,
is_bot_msg=True,
is_bot_dm=False,
is_slash_command=True,
is_bot_dm=channel_name == "directmessage",
)
raise RuntimeError("Programming fault, this should never happen.")

View File

@@ -13,7 +13,7 @@ class SlackMessageInfo(BaseModel):
sender_id: str | None
email: str | None
bypass_filters: bool # User has tagged @OnyxBot
is_bot_msg: bool # User is using /OnyxBot
is_slash_command: bool # User is using /OnyxBot
is_bot_dm: bool # User is direct messaging to OnyxBot
@@ -25,7 +25,7 @@ class ActionValuesEphemeralMessageMessageInfo(BaseModel):
email: str | None
sender_id: str | None
thread_messages: list[ThreadMessage] | None
is_bot_msg: bool | None
is_slash_command: bool | None
is_bot_dm: bool | None
thread_to_respond: str | None

View File

@@ -3,7 +3,6 @@ import redis
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_pool import get_redis_client
@@ -31,11 +30,6 @@ class RedisConnector:
tenant_id, cc_pair_id, self.redis
)
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
return RedisConnectorIndex(
self.tenant_id, self.cc_pair_id, search_settings_id, self.redis
)
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
@@ -81,3 +75,11 @@ class RedisConnector:
object_id = parts[1]
return object_id
def db_lock_key(self, search_settings_id: int) -> str:
"""
Key for the db lock for an indexing attempt.
Prevents multiple modifications to the current indexing attempt row
from multiple docfetching/docprocessing tasks.
"""
return f"da_lock:indexing:db_{self.cc_pair_id}/{search_settings_id}"
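The lock is scoped per (connector-credential pair, search settings) combination, so concurrent docfetching and docprocessing tasks for the same attempt contend on one key. The key formatting on its own:

```python
def db_lock_key(cc_pair_id: int, search_settings_id: int) -> str:
    # One Redis lock key per (cc_pair, search_settings) indexing attempt
    return f"da_lock:indexing:db_{cc_pair_id}/{search_settings_id}"
```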

View File

@@ -1,126 +1,10 @@
from datetime import datetime
from typing import cast
import redis
from pydantic import BaseModel
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
class RedisConnectorIndexPayload(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None
class RedisConnectorIndex:
"""Manages interactions with redis for indexing tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectorindexing"
FENCE_PREFIX = f"{PREFIX}_fence" # "connectorindexing_fence"
GENERATOR_TASK_PREFIX = PREFIX + "+generator" # "connectorindexing+generator_fence"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # connectorindexing_generator_progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # connectorindexing_generator_complete
GENERATOR_LOCK_PREFIX = "da_lock:indexing:docfetching"
FILESTORE_LOCK_PREFIX = "da_lock:indexing:filestore"
DB_LOCK_PREFIX = "da_lock:indexing:db"
PER_WORKER_LOCK_PREFIX = "da_lock:indexing:per_worker"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
TERMINATE_TTL = 600
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
# used to signal that the watchdog is running
WATCHDOG_PREFIX = PREFIX + "_watchdog"
WATCHDOG_TTL = 300
# used to signal that the connector itself is still running
CONNECTOR_ACTIVE_PREFIX = PREFIX + "_connector_active"
CONNECTOR_ACTIVE_TTL = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
def __init__(
self,
tenant_id: str,
cc_pair_id: int,
search_settings_id: int,
redis: redis.Redis,
) -> None:
self.tenant_id: str = tenant_id
self.cc_pair_id = cc_pair_id
self.search_settings_id = search_settings_id
self.redis = redis
self.generator_complete_key = (
f"{self.GENERATOR_COMPLETE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.filestore_lock_key = (
f"{self.FILESTORE_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.per_worker_lock_key = (
f"{self.PER_WORKER_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.db_lock_key = f"{self.DB_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
self.terminate_key = (
f"{self.TERMINATE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
return
self.redis.set(self.generator_complete_key, payload)
def generator_clear(self) -> None:
self.redis.delete(self.generator_complete_key)
def get_completion(self) -> int | None:
bytes = self.redis.get(self.generator_complete_key)
if bytes is None:
return None
status = int(cast(int, bytes))
return status
def reset(self) -> None:
self.redis.delete(self.filestore_lock_key)
self.redis.delete(self.db_lock_key)
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_complete_key)
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
# leaving these temporarily for backwards compat, TODO: remove
for key in r.scan_iter(RedisConnectorIndex.CONNECTOR_ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.FILESTORE_LOCK_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -28,10 +28,7 @@ class RedisDocumentSet(RedisObjectHelper):
@property
def fenced(self) -> bool:
if self.redis.exists(self.fence_key):
return True
return False
return bool(self.redis.exists(self.fence_key))
def set_fence(self, payload: int | None) -> None:
if payload is None:

View File

@@ -1,6 +1,5 @@
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_usergroup import RedisUserGroup
@@ -16,8 +15,6 @@ def is_fence(key_bytes: bytes) -> bool:
return True
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
return True

View File

@@ -1,3 +1,4 @@
import io
import json
import mimetypes
import os
@@ -101,8 +102,9 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import User
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.file_processing.extract_file_text import convert_docx_to_txt
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.documents.models import AuthStatus
from onyx.server.documents.models import AuthUrl
@@ -124,6 +126,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -418,7 +421,29 @@ def extract_zip_metadata(zf: zipfile.ZipFile) -> dict[str, Any]:
return zip_metadata
def upload_files(files: list[UploadFile]) -> FileUploadResponse:
def is_zip_file(file: UploadFile) -> bool:
"""
Check if the file is a zip file by content type or filename.
"""
return bool(
(
file.content_type
and file.content_type.startswith(
(
"application/zip",
"application/x-zip-compressed", # May be this in Windows
"application/x-zip",
"multipart/x-zip",
)
)
)
or (file.filename and file.filename.lower().endswith(".zip"))
)
def upload_files(
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
) -> FileUploadResponse:
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
@@ -429,12 +454,13 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
return not any(part.startswith(".") for part in normalized_path.split(os.sep))
deduped_file_paths = []
deduped_file_names = []
zip_metadata = {}
try:
file_store = get_default_file_store()
seen_zip = False
for file in files:
if file.content_type and file.content_type.startswith("application/zip"):
if is_zip_file(file):
if seen_zip:
raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL)
seen_zip = True
@@ -460,14 +486,24 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
file_type=mime_type,
)
deduped_file_paths.append(file_id)
deduped_file_names.append(os.path.basename(file_info))
continue
# Special handling for docx files - only store the plaintext version
if file.content_type and file.content_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
docx_file_id = convert_docx_to_txt(file, file_store)
deduped_file_paths.append(docx_file_id)
# For mypy; the real check happens at the start of the function
assert file.filename is not None
# Special handling for doc files - only store the plaintext version
file_type = mime_type_to_chat_file_type(file.content_type)
if file_type == ChatFileType.DOC:
extracted_text = extract_file_text(file.file, file.filename or "")
text_file_id = file_store.save_file(
content=io.BytesIO(extracted_text.encode()),
display_name=file.filename,
file_origin=file_origin,
file_type="text/plain",
)
deduped_file_paths.append(text_file_id)
deduped_file_names.append(file.filename)
continue
# Default handling for all other file types
@@ -478,10 +514,15 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
file_type=file.content_type or "text/plain",
)
deduped_file_paths.append(file_id)
deduped_file_names.append(file.filename)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
return FileUploadResponse(file_paths=deduped_file_paths, zip_metadata=zip_metadata)
return FileUploadResponse(
file_paths=deduped_file_paths,
file_names=deduped_file_names,
zip_metadata=zip_metadata,
)
@router.post("/admin/connector/file/upload")
@@ -489,7 +530,7 @@ def upload_files_api(
files: list[UploadFile],
_: User = Depends(current_curator_or_admin_user),
) -> FileUploadResponse:
return upload_files(files)
return upload_files(files, FileOrigin.OTHER)
@router.get("/admin/connector")

View File

@@ -475,6 +475,7 @@ class GoogleServiceAccountCredentialRequest(BaseModel):
class FileUploadResponse(BaseModel):
file_paths: list[str]
file_names: list[str]
zip_metadata: dict[str, Any]

View File

@@ -179,15 +179,19 @@ def get_kg_entity_types(
) -> SourceAndEntityTypeView:
# when using for the first time, populate with default entity types
entity_types = {
key: [EntityType.from_model(et) for et in ets]
for key, ets in get_configured_entity_types(db_session=db_session).items()
source_name: [EntityType.from_model(et) for et in ets]
for source_name, ets in get_configured_entity_types(
db_session=db_session
).items()
}
source_statistics = {
key: SourceStatistics(
source_name=key, last_updated=last_updated, entities_count=entities_count
source_name: SourceStatistics(
source_name=source_name,
last_updated=last_updated,
entities_count=entities_count,
)
for key, (
for source_name, (
last_updated,
entities_count,
) in get_entity_stats_by_grounded_source_name(db_session=db_session).items()

View File

@@ -206,5 +206,5 @@ def create_deletion_attempt_for_connector_id(
if cc_pair.connector.source == DocumentSource.FILE:
connector = cc_pair.connector
file_store = get_default_file_store()
for file_name in connector.connector_specific_config.get("file_locations", []):
file_store.delete_file(file_name)
for file_id in connector.connector_specific_config.get("file_locations", []):
file_store.delete_file(file_id)

View File

@@ -1,6 +1,5 @@
import asyncio
import datetime
import io
import json
import os
import time
@@ -31,7 +30,6 @@ from onyx.chat.prompt_builder.citations_prompt import (
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
@@ -63,9 +61,7 @@ from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.user_documents import create_user_files
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms
@@ -717,105 +713,65 @@ def upload_files_for_chat(
):
raise HTTPException(
status_code=400,
detail="File size must be less than 20MB",
detail="Images must be less than 20MB",
)
file_store = get_default_file_store()
file_info: list[tuple[str, str | None, ChatFileType]] = []
for file in files:
file_type = mime_type_to_chat_file_type(file.content_type)
file_content = file.file.read() # Read the file content
# NOTE: Image conversion to JPEG used to be enforced here.
# This was removed to:
# 1. Preserve original file content for downloads
# 2. Maintain transparency in formats like PNG
# 3. Avoid issues introduced by file conversion
file_content_io = io.BytesIO(file_content)
new_content_type = file.content_type
# Store the file normally
file_id = file_store.save_file(
content=file_content_io,
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=new_content_type or file_type.value,
# 5) Create a user file for each uploaded file
user_files = create_user_files(files, RECENT_DOCS_FOLDER_ID, user, db_session)
for user_file in user_files:
# 6) Create connector
connector_base = ConnectorBase(
name=f"UserFile-{int(time.time())}",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"file_names": [user_file.name],
"zip_metadata": {},
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)
connector = create_connector(
db_session=db_session,
connector_data=connector_base,
)
# 4) If the file is a doc, extract text and store that separately
if file_type == ChatFileType.DOC:
# Re-wrap bytes in a fresh BytesIO so we start at position 0
extracted_text_io = io.BytesIO(file_content)
extracted_text = extract_file_text(
file=extracted_text_io, # use the bytes we already read
file_name=file.filename or "",
)
# 7) Create credential
credential_info = CredentialBase(
credential_json={},
admin_public=True,
source=DocumentSource.FILE,
curator_public=True,
groups=[],
name=f"UserFileCredential-{int(time.time())}",
is_user_file=True,
)
credential = create_credential(credential_info, user, db_session)
text_file_id = file_store.save_file(
content=io.BytesIO(extracted_text.encode()),
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type="text/plain",
)
# Return the text file as the "main" file descriptor for doc types
file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT))
else:
file_info.append((file_id, file.filename, file_type))
# 5) Create a user file for each uploaded file
user_files = create_user_files([file], RECENT_DOCS_FOLDER_ID, user, db_session)
for user_file in user_files:
# 6) Create connector
connector_base = ConnectorBase(
name=f"UserFile-{int(time.time())}",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"zip_metadata": {},
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)
connector = create_connector(
db_session=db_session,
connector_data=connector_base,
)
# 7) Create credential
credential_info = CredentialBase(
credential_json={},
admin_public=True,
source=DocumentSource.FILE,
curator_public=True,
groups=[],
name=f"UserFileCredential-{int(time.time())}",
is_user_file=True,
)
credential = create_credential(credential_info, user, db_session)
# 8) Create connector credential pair
cc_pair = add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=connector.id,
credential_id=credential.id,
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
access_type=AccessType.PRIVATE,
auto_sync_options=None,
groups=[],
)
user_file.cc_pair_id = cc_pair.data
db_session.commit()
# 8) Create connector credential pair
cc_pair = add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=connector.id,
credential_id=credential.id,
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
access_type=AccessType.PRIVATE,
auto_sync_options=None,
groups=[],
)
user_file.cc_pair_id = cc_pair.data
db_session.commit()
return {
"files": [
{"id": file_id, "type": file_type, "name": file_name}
for file_id, file_name, file_type in file_info
{
"id": user_file.file_id,
"type": mime_type_to_chat_file_type(user_file.content_type),
"name": user_file.name,
}
for user_file in user_files
]
}

View File

@@ -408,6 +408,7 @@ def create_file_from_link(
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"file_names": [user_file.name],
"zip_metadata": {},
},
refresh_freq=None,

View File

@@ -44,12 +44,12 @@ litellm==1.72.2
lxml==5.3.0
lxml_html_clean==0.2.2
Mako==1.2.4
markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2
msal==1.28.0
nltk==3.9.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.75.0
openpyxl==3.0.10
passlib==1.7.4
playwright==1.41.2
psutil==5.9.5
@@ -66,7 +66,7 @@ pypdf==5.4.0
pytest-mock==3.12.0
pytest-playwright==0.7.0
python-docx==1.1.2
python-dotenv==1.0.0
python-dotenv==1.1.1
python-multipart==0.0.20
pywikibot==9.0.0
redis==5.0.8

View File

@@ -22,7 +22,6 @@ from onyx.configs.app_configs import REDIS_SSL
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_pool import RedisPool
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -130,9 +129,6 @@ def onyx_redis(
logger.info(f"Purging locks associated with deleting cc_pair={cc_pair_id}.")
redis_connector = RedisConnector(tenant_id, cc_pair_id)
match_pattern = f"{tenant_id}:{RedisConnectorIndex.FENCE_PREFIX}_{cc_pair_id}/*"
purge_by_match_and_type(match_pattern, "string", batch, dry_run, r)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.prune.fence_key}", dry_run, r
)

View File

@@ -187,7 +187,7 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
f"{connector_id} and Credential ID: {credential_id} does not exist."
)
file_names: list[str] = (
file_ids: list[str] = (
cc_pair.connector.connector_specific_config["file_locations"]
if cc_pair.connector.source == DocumentSource.FILE
else []
@@ -211,12 +211,12 @@ def _delete_connector(cc_pair_id: int, db_session: Session) -> None:
except Exception as e:
logger.error(f"Failed to delete connector due to {e}")
if file_names:
if file_ids:
logger.notice("Deleting stored files!")
file_store = get_default_file_store()
for file_name in file_names:
logger.notice(f"Deleting file {file_name}")
file_store.delete_file(file_name)
for file_id in file_ids:
logger.notice(f"Deleting file {file_id}")
file_store.delete_file(file_id)
db_session.commit()

View File

@@ -10,6 +10,8 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.connector import ConfluenceConnector
from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider
from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from tests.daily.connectors.utils import load_all_docs_from_checkpoint_connector
@@ -101,7 +103,17 @@ def test_confluence_connector_restriction_handling(
mock_cc_pair.credential_id = 1
# Call the confluence_doc_sync function directly with the mock cc_pair
doc_access_generator = confluence_doc_sync(mock_cc_pair, lambda: [], None)
def mock_fetch_all_docs_fn(
sort_order: SortOrder | None = None,
) -> list[DocumentRow]:
return []
def mock_fetch_all_docs_ids_fn() -> list[str]:
return []
doc_access_generator = confluence_doc_sync(
mock_cc_pair, mock_fetch_all_docs_fn, mock_fetch_all_docs_ids_fn, None
)
doc_access_list = list(doc_access_generator)
assert len(doc_access_list) == 7
assert all(

View File

@@ -56,7 +56,9 @@ def test_single_text_file_with_metadata(
"onyx.connectors.file.connector.get_default_file_store",
return_value=mock_file_store,
):
connector = LocalFileConnector(file_locations=["test.txt"], zip_metadata={})
connector = LocalFileConnector(
file_locations=["test.txt"], file_names=["test.txt"], zip_metadata={}
)
batches = list(connector.load_from_state())
assert len(batches) == 1
@@ -113,7 +115,9 @@ def test_two_text_files_with_zip_metadata(
return_value=mock_file_store,
):
connector = LocalFileConnector(
file_locations=["file1.txt", "file2.txt"], zip_metadata=zip_metadata
file_locations=["file1.txt", "file2.txt"],
file_names=["file1.txt", "file2.txt"],
zip_metadata=zip_metadata,
)
batches = list(connector.load_from_state())

View File

@@ -10,6 +10,8 @@ from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync
from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from tests.daily.connectors.google_drive.consts_and_utils import ACCESS_MAPPING
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
@@ -71,7 +73,20 @@ def test_gdrive_perm_sync_with_real_data(
return_value=_build_connector(google_drive_service_acct_connector_factory),
):
# Call the function under test
doc_access_generator = gdrive_doc_sync(mock_cc_pair, lambda: [], mock_heartbeat)
def mock_fetch_all_docs_fn(
sort_order: SortOrder | None = None,
) -> list[DocumentRow]:
return []
def mock_fetch_all_docs_ids_fn() -> list[str]:
return []
doc_access_generator = gdrive_doc_sync(
mock_cc_pair,
mock_fetch_all_docs_fn,
mock_fetch_all_docs_ids_fn,
mock_heartbeat,
)
doc_access_list = list(doc_access_generator)
# Verify we got some results

View File

@@ -11,6 +11,7 @@ import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
def extract_key_value_pairs_to_set(
@@ -35,7 +36,7 @@ def _load_reference_data(
@pytest.fixture
def salesforce_connector() -> SalesforceConnector:
connector = SalesforceConnector(
requested_objects=["Account", "Contact", "Opportunity"],
requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact", "Opportunity"],
)
username = os.environ["SF_USERNAME"]

View File

@@ -85,12 +85,12 @@ def sharepoint_credentials() -> dict[str, str]:
}
def test_sharepoint_connector_all_sites(
def test_sharepoint_connector_all_sites__docs_only(
mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
# Initialize connector with no sites
connector = SharepointConnector()
connector = SharepointConnector(include_site_pages=False)
# Load credentials
connector.load_credentials(sharepoint_credentials)
@@ -135,12 +135,14 @@ def test_sharepoint_connector_specific_folder(
verify_document_content(doc, expected)
def test_sharepoint_connector_root_folder(
def test_sharepoint_connector_root_folder__docs_only(
mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
# Initialize connector with the base site URL
connector = SharepointConnector(sites=[os.environ["SHAREPOINT_SITE"]])
connector = SharepointConnector(
sites=[os.environ["SHAREPOINT_SITE"]], include_site_pages=False
)
# Load credentials
connector.load_credentials(sharepoint_credentials)
@@ -225,3 +227,59 @@ def test_sharepoint_connector_poll(
verify_document_content(
doc, [d for d in EXPECTED_DOCUMENTS if d.semantic_identifier == "test1.docx"][0]
)
def test_sharepoint_connector_pages(
mock_get_unstructured_api_key: MagicMock,
sharepoint_credentials: dict[str, str],
) -> None:
# Initialize connector with the base site URL
connector = SharepointConnector(
sites=["https://danswerai.sharepoint.com/sites/sharepoint-tests-pages"]
)
# Load credentials
connector.load_credentials(sharepoint_credentials)
# Get documents within the time window
document_batches = list(connector.load_from_state())
found_documents: list[Document] = [
doc for batch in document_batches for doc in batch
]
# Should only find CollabHome
assert len(found_documents) == 1, "Should only find one page"
doc = found_documents[0]
assert doc.semantic_identifier == "CollabHome"
verify_document_metadata(doc)
assert len(doc.sections) == 1
assert (
doc.sections[0].text
== """
# Home
Display recent news.
## News
Show recent activities from your site
## Site activity
## Quick links
Learn about a team site
Learn how to add a page
Add links to important documents and pages.
## Quick links
Documents
Add a document library
## Document library
""".strip()
)

View File

@@ -33,7 +33,7 @@ def _load_all_docs(
for document, failure, next_checkpoint in doc_batch_generator:
if failure is not None:
raise RuntimeError(f"Failed to load documents: {failure}")
if document is not None:
if document is not None and isinstance(document, Document):
documents.append(document)
if next_checkpoint is not None:
checkpoint = next_checkpoint
@@ -100,7 +100,7 @@ def load_everything_from_checkpoint_connector(
for document, failure, next_checkpoint in doc_batch_generator:
if failure is not None:
outputs.append(failure)
if document is not None:
if document is not None and isinstance(document, Document):
outputs.append(document)
if next_checkpoint is not None:
checkpoint = next_checkpoint
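The `isinstance` guards above exist because the checkpoint generator can yield failures alongside documents; the filtering pattern in isolation (with hypothetical stand-in types, not the real connector models):

```python
from dataclasses import dataclass


@dataclass
class Document:
    id: str


@dataclass
class ConnectorFailure:
    message: str


def collect_documents(outputs: list) -> list[Document]:
    # Keep only real Document instances; skip None and failure objects
    return [
        item for item in outputs if item is not None and isinstance(item, Document)
    ]


mixed = [Document("doc-1"), None, ConnectorFailure("timeout"), Document("doc-2")]
docs = collect_documents(mixed)
assert [d.id for d in docs] == ["doc-1", "doc-2"]
```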

View File

@@ -34,7 +34,7 @@ class ConnectorManager:
connector_specific_config=(
connector_specific_config
or (
{"file_locations": [], "zip_metadata": {}}
{"file_locations": [], "file_names": [], "zip_metadata": {}}
if source == DocumentSource.FILE
else {}
)

View File

@@ -21,6 +21,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
FILE_NAME = "Sample.pdf"
FILE_PATH = "tests/integration/common_utils/test_files"
DOCX_FILE_NAME = "three_images.docx"
def test_image_indexing(
@@ -67,7 +68,11 @@ def test_image_indexing(
name=connector_name,
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={"file_locations": file_paths, "zip_metadata": {}},
connector_specific_config={
"file_locations": file_paths,
"file_names": [FILE_NAME],
"zip_metadata": {},
},
access_type=AccessType.PUBLIC,
groups=[],
user_performing_action=admin_user,
@@ -110,3 +115,112 @@ def test_image_indexing(
else:
assert document.image_file_id is not None
assert file_paths[0] in document.image_file_id
def test_docx_image_indexing(
reset: None,
admin_user: DATestUser,
vespa_client: vespa_fixture,
) -> None:
"""Test that images from docx files are correctly extracted and indexed."""
os.makedirs(FILE_PATH, exist_ok=True)
test_file_path = os.path.join(FILE_PATH, DOCX_FILE_NAME)
# Use FileManager to upload the test file
upload_response = FileManager.upload_file_for_connector(
file_path=test_file_path,
file_name=DOCX_FILE_NAME,
user_performing_action=admin_user,
)
LLMProviderManager.create(
name="test_llm_docx",
user_performing_action=admin_user,
)
SettingsManager.update_settings(
DATestSettings(
search_time_image_analysis_enabled=True,
image_extraction_and_analysis_enabled=True,
),
user_performing_action=admin_user,
)
file_paths = upload_response.file_paths
if not file_paths:
pytest.fail("File upload failed - no file paths returned")
# Create a dummy credential for the file connector
credential = CredentialManager.create(
source=DocumentSource.FILE,
credential_json={},
user_performing_action=admin_user,
)
# Create the connector
connector_name = f"DocxFileConnector-{int(datetime.now().timestamp())}"
connector = ConnectorManager.create(
name=connector_name,
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": file_paths,
"file_names": [DOCX_FILE_NAME],
"zip_metadata": {},
},
access_type=AccessType.PUBLIC,
groups=[],
user_performing_action=admin_user,
)
# Link the credential to the connector
cc_pair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.PUBLIC,
user_performing_action=admin_user,
)
# Explicitly run the connector to start indexing
CCPairManager.run_once(
cc_pair=cc_pair,
from_beginning=True,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=datetime.now(timezone.utc),
timeout=300,
user_performing_action=admin_user,
)
with get_session_with_current_tenant() as db_session:
# Fetch documents from Vespa - expect text content plus 3 images
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
# Should have documents for text content plus 3 images
assert (
len(documents) >= 3
), f"Expected at least 3 documents (3 images), got {len(documents)}"
# Count documents with images
image_documents = [doc for doc in documents if doc.image_file_id is not None]
text_documents = [doc for doc in documents if doc.image_file_id is None]
assert (
len(image_documents) == 3
), f"Expected exactly 3 image documents, got {len(image_documents)}"
assert (
len(text_documents) >= 1
), f"Expected at least 1 text document, got {len(text_documents)}"
# Verify each image document has a valid image_file_id pointing to our uploaded file
for image_doc in image_documents:
assert file_paths[0] in (
image_doc.image_file_id or ""
), f"Image document should reference uploaded file: {image_doc.image_file_id}"

View File

@@ -76,6 +76,7 @@ def test_zip_metadata_handling(
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": file_paths,
"file_names": [os.path.basename(file_path) for file_path in file_paths],
"zip_metadata": metadata,
},
access_type=AccessType.PUBLIC,

View File

@@ -65,7 +65,11 @@ def kg_test_docs() -> tuple[list[str], int, list[KGEntityType]]:
name="KG-Test-FileConnector",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={"file_locations": [], "zip_metadata": {}},
connector_specific_config={
"file_locations": [],
"file_names": [],
"zip_metadata": {},
},
user_performing_action=admin_user,
)
api_key = APIKeyManager.create(user_performing_action=admin_user)

View File

@@ -160,7 +160,11 @@ def create_connector(env_name: str, file_paths: list[str]) -> int:
name=connector_name,
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={"file_locations": file_paths, "zip_metadata": {}},
connector_specific_config={
"file_locations": file_paths,
"file_names": [], # For regression tests, no need for file_names
"zip_metadata": {},
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,

View File

@@ -23,7 +23,7 @@ from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.github.connector import GithubConnector
from onyx.connectors.github.connector import GithubConnectorStage
from onyx.connectors.github.connector import SerializedRepository
from onyx.connectors.github.models import SerializedRepository
from onyx.connectors.models import Document
from tests.unit.onyx.connectors.utils import load_everything_from_checkpoint_connector
from tests.unit.onyx.connectors.utils import (
@@ -97,6 +97,10 @@ def create_mock_pr() -> Callable[..., MagicMock]:
else f"https://github.com/test-org/test-repo/pull/{number}"
)
mock_pr.raw_data = {}
mock_pr.base = MagicMock()
mock_pr.base.repo = MagicMock()
mock_pr.base.repo.full_name = "test-org/test-repo"
return mock_pr
return _create_mock_pr
@@ -121,6 +125,11 @@ def create_mock_issue() -> Callable[..., MagicMock]:
mock_issue.html_url = f"https://github.com/test-org/test-repo/issues/{number}"
mock_issue.pull_request = None # Not a PR
mock_issue.raw_data = {}
# Mock the nested base.repo.full_name attribute
mock_issue.repository = MagicMock()
mock_issue.repository.full_name = "test-org/test-repo"
return mock_issue
return _create_mock_issue
@@ -265,7 +274,7 @@ def test_load_from_checkpoint_with_rate_limit(
# Call load_from_checkpoint
end_time = time.time()
with patch(
"onyx.connectors.github.connector._sleep_after_rate_limit_exception"
"onyx.connectors.github.connector.sleep_after_rate_limit_exception"
) as mock_sleep:
outputs = load_everything_from_checkpoint_connector(
github_connector, 0, end_time
@@ -797,7 +806,7 @@ def test_load_from_checkpoint_cursor_pagination_completion(
mock_repo1.get_issues.return_value = mock_empty_issues_list
mock_repo2.get_issues.return_value = mock_empty_issues_list
with patch.object(
github_connector, "_get_all_repos", return_value=[mock_repo1, mock_repo2]
github_connector, "get_all_repos", return_value=[mock_repo1, mock_repo2]
), patch.object(
github_connector,
"_pull_requests_func",

View File

@@ -34,10 +34,16 @@ def mock_fetch_all_existing_docs_fn() -> MagicMock:
return MagicMock(return_value=[])
@pytest.fixture
def mock_fetch_all_existing_docs_ids_fn() -> MagicMock:
return MagicMock(return_value=[])
def test_jira_permission_sync(
jira_connector: JiraConnector,
mock_jira_cc_pair: MagicMock,
mock_fetch_all_existing_docs_fn: MagicMock,
mock_fetch_all_existing_docs_ids_fn: MagicMock,
) -> None:
with patch("onyx.connectors.jira.connector.build_jira_client") as mock_build_client:
mock_build_client.return_value = jira_connector._jira_client
@@ -45,5 +51,6 @@ def test_jira_permission_sync(
for doc in jira_doc_sync(
cc_pair=mock_jira_cc_pair,
fetch_all_existing_docs_fn=mock_fetch_all_existing_docs_fn,
fetch_all_existing_docs_ids_fn=mock_fetch_all_existing_docs_ids_fn,
):
print(doc)

View File

@@ -0,0 +1,129 @@
#!/usr/bin/env python3
"""
Test script for the new custom query configuration functionality in SalesforceConnector.
This demonstrates how to use the new custom_query_config parameter to specify
exactly which fields and associations (child objects) to retrieve for each object type.
"""
import json
from typing import Any
from onyx.connectors.salesforce.connector import _validate_custom_query_config
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
def test_custom_query_config() -> None:
"""Test the custom query configuration functionality."""
# Example custom query configuration
# This specifies exactly which fields and associations to retrieve
custom_config = {
ACCOUNT_OBJECT_TYPE: {
"fields": ["Id", "Name", "Industry", "CreatedDate", MODIFIED_FIELD],
"associations": {
"Contact": ["Id", "FirstName", "LastName", "Email"],
"Opportunity": ["Id", "Name", "StageName", "Amount", "CloseDate"],
},
},
"Lead": {
"fields": ["Id", "FirstName", "LastName", "Company", "Status"],
"associations": {}, # No associations for Lead
},
}
# Create connector with custom configuration
connector = SalesforceConnector(
batch_size=50, custom_query_config=json.dumps(custom_config)
)
print("✅ SalesforceConnector created successfully with custom query config")
print(f"Parent object list: {connector.parent_object_list}")
print(f"Custom config keys: {list(custom_config.keys())}")
# Test that the parent object list is derived from the custom config
assert connector.parent_object_list == [ACCOUNT_OBJECT_TYPE, "Lead"]
assert connector.custom_query_config == custom_config
print("✅ Basic validation passed")
def test_traditional_config() -> None:
"""Test that the traditional requested_objects approach still works."""
# Traditional approach
connector = SalesforceConnector(
batch_size=50, requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact"]
)
print("✅ SalesforceConnector created successfully with traditional config")
print(f"Parent object list: {connector.parent_object_list}")
# Test that it still works the old way
assert connector.parent_object_list == [ACCOUNT_OBJECT_TYPE, "Contact"]
assert connector.custom_query_config is None
print("✅ Traditional config validation passed")
def test_validation() -> None:
"""Test that invalid configurations are rejected."""
# Test invalid config structure
invalid_configs: list[Any] = [
# Invalid fields type
{ACCOUNT_OBJECT_TYPE: {"fields": "invalid"}},
# Invalid associations type
{ACCOUNT_OBJECT_TYPE: {"associations": "invalid"}},
# Nested invalid structure
{ACCOUNT_OBJECT_TYPE: {"associations": {"Contact": {"fields": "invalid"}}}},
]
for i, invalid_config in enumerate(invalid_configs):
try:
_validate_custom_query_config(invalid_config)
assert False, f"Should have raised ValueError for invalid_config[{i}]"
except ValueError:
print(f"✅ Correctly rejected invalid config {i}")
if __name__ == "__main__":
print("Testing SalesforceConnector custom query configuration...")
print("=" * 60)
test_custom_query_config()
print()
test_traditional_config()
print()
test_validation()
print()
print("=" * 60)
print("🎉 All tests passed! The custom query configuration is working correctly.")
print()
print("Example usage:")
print(
"""
# Custom configuration approach
custom_config = {
ACCOUNT_OBJECT_TYPE: {
"fields": ["Id", "Name", "Industry"],
"associations": {
"Contact": {
"fields": ["Id", "FirstName", "LastName", "Email"],
"associations": {}
}
}
}
}
connector = SalesforceConnector(custom_query_config=custom_config)
# Traditional approach (still works)
connector = SalesforceConnector(requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact"])
"""
)
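A rough sketch of the shape checks the invalid-config cases above rely on. This is an assumption about what `_validate_custom_query_config` enforces ("fields" must be lists, "associations" must be dicts, nested specs are validated recursively), not the actual implementation:

```python
from typing import Any


def validate_custom_query_config(config: dict[str, Any]) -> None:
    """Hypothetical validator sketch, not the real _validate_custom_query_config."""
    for object_type, spec in config.items():
        fields = spec.get("fields", [])
        if not isinstance(fields, list):
            raise ValueError(f"{object_type}: 'fields' must be a list")
        associations = spec.get("associations", {})
        if not isinstance(associations, dict):
            raise ValueError(f"{object_type}: 'associations' must be a dict")
        for child, child_spec in associations.items():
            if isinstance(child_spec, list):
                continue  # a flat list of field names is fine
            if isinstance(child_spec, dict):
                # recurse into a nested {"fields": ..., "associations": ...} spec
                validate_custom_query_config({child: child_spec})
            else:
                raise ValueError(
                    f"{object_type}.{child}: must be a field list or nested spec"
                )


# The three invalid shapes from the test above should all be rejected
for bad in (
    {"Account": {"fields": "invalid"}},
    {"Account": {"associations": "invalid"}},
    {"Account": {"associations": {"Contact": {"fields": "invalid"}}}},
):
    try:
        validate_custom_query_config(bad)
        raise AssertionError("expected ValueError")
    except ValueError:
        pass
```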

View File

@@ -26,6 +26,9 @@ from onyx.connectors.salesforce.salesforce_calls import _make_time_filter_for_sf
from onyx.connectors.salesforce.salesforce_calls import _make_time_filtered_query
from onyx.connectors.salesforce.salesforce_calls import get_object_by_id_query
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.utils.logger import setup_logger
# from onyx.connectors.salesforce.onyx_salesforce_type import OnyxSalesforceType
@@ -153,7 +156,7 @@ def _create_csv_file_and_update_db(
Creates a CSV file for the given object type and records.
Args:
object_type: The Salesforce object type (e.g. "Account", "Contact")
object_type: The Salesforce object type (e.g. ACCOUNT_OBJECT_TYPE, "Contact")
records: List of dictionaries containing the record data
filename: Name of the CSV file to create (default: test_data.csv)
"""
@@ -184,7 +187,7 @@ def _create_csv_with_example_data(sf_db: OnyxSalesforceSQLite) -> None:
Creates CSV files with example data, organized by object type.
"""
example_data: dict[str, list[dict]] = {
"Account": [
ACCOUNT_OBJECT_TYPE: [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc.",
@@ -428,7 +431,7 @@ def _test_query(sf_db: OnyxSalesforceSQLite) -> None:
}
# Get all Account IDs
account_ids = sf_db.find_ids_by_type("Account")
account_ids = sf_db.find_ids_by_type(ACCOUNT_OBJECT_TYPE)
# Verify we found all expected accounts
assert len(account_ids) == len(
@@ -480,7 +483,9 @@ def _test_upsert(sf_db: OnyxSalesforceSQLite) -> None:
},
]
_create_csv_file_and_update_db(sf_db, "Account", update_data, "update_data.csv")
_create_csv_file_and_update_db(
sf_db, ACCOUNT_OBJECT_TYPE, update_data, "update_data.csv"
)
# Verify the update worked
updated_record = sf_db.get_record(_VALID_SALESFORCE_IDS[0])
@@ -573,7 +578,7 @@ def _test_account_with_children(sf_db: OnyxSalesforceSQLite) -> None:
3. Child object data is complete and accurate
"""
# First get all account IDs
account_ids = sf_db.find_ids_by_type("Account")
account_ids = sf_db.find_ids_by_type(ACCOUNT_OBJECT_TYPE)
assert len(account_ids) > 0, "No accounts found"
# For each account, get its children and verify the data
@@ -690,7 +695,7 @@ def _test_get_affected_parent_ids(sf_db: OnyxSalesforceSQLite) -> None:
"""
# Create test data with relationships
test_data = {
"Account": [
ACCOUNT_OBJECT_TYPE: [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Parent Account 1",
@@ -720,40 +725,46 @@ def _test_get_affected_parent_ids(sf_db: OnyxSalesforceSQLite) -> None:
# Test Case 1: Account directly in updated_ids and parent_types
updated_ids = [_VALID_SALESFORCE_IDS[1]] # Parent Account 2
parent_types = set(["Account"])
parent_types = set([ACCOUNT_OBJECT_TYPE])
affected_ids_by_type = defaultdict(set)
for parent_type, parent_id, _ in sf_db.get_changed_parent_ids_by_type(
updated_ids, parent_types
):
affected_ids_by_type[parent_type].add(parent_id)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[1] in affected_ids_by_type["Account"]
ACCOUNT_OBJECT_TYPE in affected_ids_by_type
), "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[1] in affected_ids_by_type[ACCOUNT_OBJECT_TYPE]
), "Direct parent ID not included"
# Test Case 2: Account with child in updated_ids
updated_ids = [_VALID_SALESFORCE_IDS[40]] # Child Contact
parent_types = set(["Account"])
parent_types = set([ACCOUNT_OBJECT_TYPE])
affected_ids_by_type = defaultdict(set)
for parent_type, parent_id, _ in sf_db.get_changed_parent_ids_by_type(
updated_ids, parent_types
):
affected_ids_by_type[parent_type].add(parent_id)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids_by_type["Account"]
ACCOUNT_OBJECT_TYPE in affected_ids_by_type
), "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids_by_type[ACCOUNT_OBJECT_TYPE]
), "Parent of updated child not included"
# Test Case 3: Both direct and indirect affects
updated_ids = [_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]] # Both cases
parent_types = set(["Account"])
parent_types = set([ACCOUNT_OBJECT_TYPE])
affected_ids_by_type = defaultdict(set)
for parent_type, parent_id, _ in sf_db.get_changed_parent_ids_by_type(
updated_ids, parent_types
):
affected_ids_by_type[parent_type].add(parent_id)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
affected_ids = affected_ids_by_type["Account"]
assert (
ACCOUNT_OBJECT_TYPE in affected_ids_by_type
), "Account type not in affected_ids_by_type"
affected_ids = affected_ids_by_type[ACCOUNT_OBJECT_TYPE]
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
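The aggregation pattern shared by all three test cases above, grouping changed parent IDs by object type, reduces to this sketch. The generator here is a stand-in for `sf_db.get_changed_parent_ids_by_type`, whose real lookup logic is not shown in this diff; the toy child-to-parent map and the tuple's third element are assumptions for illustration only.

```python
from collections import defaultdict
from typing import Iterator

# Stand-in for sf_db.get_changed_parent_ids_by_type: yields
# (parent_type, parent_id, reason) tuples for each affected parent.
def changed_parent_ids(
    updated_ids: list[str], parent_types: set[str]
) -> Iterator[tuple[str, str, str]]:
    child_to_parent = {"003X": ("Account", "001A")}  # child id -> (type, parent id)
    for object_id in updated_ids:
        if object_id in child_to_parent:
            parent_type, parent_id = child_to_parent[object_id]
            if parent_type in parent_types:
                yield parent_type, parent_id, "child updated"
        elif "Account" in parent_types:
            # In this sketch, any unknown id is treated as a directly-updated Account.
            yield "Account", object_id, "directly updated"

affected_ids_by_type: dict[str, set[str]] = defaultdict(set)
for parent_type, parent_id, _ in changed_parent_ids(["001B", "003X"], {"Account"}):
    affected_ids_by_type[parent_type].add(parent_id)
```

`defaultdict(set)` removes the need to check for a key before adding, and the set deduplicates parents reached both directly and via an updated child.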
@@ -929,7 +940,7 @@ def _get_child_records_by_id_query(
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> str:
"""Returns a SOQL query given the object id, type and child relationships.
@@ -963,7 +974,7 @@ def test_salesforce_connector_single() -> None:
# this record has some opportunity child records
parent_id = "001bm00000BXfhEAAT"
parent_type = "Account"
parent_type = ACCOUNT_OBJECT_TYPE
parent_types = [parent_type]
username = os.environ["SF_USERNAME"]
@@ -987,11 +998,11 @@ def test_salesforce_connector_single() -> None:
child_to_parent_types: dict[str, set[str]] = (
{}
) # reverse map from child to parent types
child_relationship_to_queryable_fields: dict[str, list[str]] = {}
child_relationship_to_queryable_fields: dict[str, set[str]] = {}
# parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {}
# Step 1 - make a list of all the types to download (parent + direct child + "User")
# Step 1 - make a list of all the types to download (parent + direct child + USER_OBJECT_TYPE)
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
for parent_type_working in parent_types:
child_types_working = sf_client.get_children_of_sf_type(parent_type_working)
@@ -1035,8 +1046,8 @@ def test_salesforce_connector_single() -> None:
result = sf_client.query(query)
records = result["records"]
record = records[0]
assert record["attributes"]["type"] == "Account"
parent_last_modified_date = record.get("LastModifiedDate", "")
assert record["attributes"]["type"] == ACCOUNT_OBJECT_TYPE
parent_last_modified_date = record.get(MODIFIED_FIELD, "")
parent_semantic_identifier = record.get("Name", "Unknown Object")
parent_last_modified_by_id = record.get("LastModifiedById")
@@ -1163,9 +1174,9 @@ def test_salesforce_connector_single() -> None:
# get user relationship if present
primary_owner_list = None
if parent_last_modified_by_id:
queryable_user_fields = sf_client.get_queryable_fields_by_type("User")
queryable_user_fields = sf_client.get_queryable_fields_by_type(USER_OBJECT_TYPE)
query = get_object_by_id_query(
parent_last_modified_by_id, "User", queryable_user_fields
parent_last_modified_by_id, USER_OBJECT_TYPE, queryable_user_fields
)
result = sf_client.query(query)
user_record = result["records"][0]


@@ -4,7 +4,7 @@ dependencies:
version: 14.3.1
- name: vespa
repository: https://onyx-dot-app.github.io/vespa-helm-charts
version: 0.2.23
version: 0.2.24
- name: nginx
repository: oci://registry-1.docker.io/bitnamicharts
version: 15.14.0
@@ -14,5 +14,5 @@ dependencies:
- name: minio
repository: oci://registry-1.docker.io/bitnamicharts
version: 17.0.4
digest: sha256:4c938cf9138e4ff6f5ecac5c044324d508ef2b0e1a23ba3f2bc089015cb40ff6
generated: "2025-06-16T18:53:19.63168-07:00"
digest: sha256:dddd687525764f5698adc339a11d268b0ee9c3ca81f8d46c9e65a6bf2c21cf25
generated: "2025-08-06T19:00:41.218513-07:00"


@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.2.2
version: 0.2.5
appVersion: latest
annotations:
category: Productivity
@@ -23,7 +23,7 @@ dependencies:
repository: https://charts.bitnami.com/bitnami
condition: postgresql.enabled
- name: vespa
version: 0.2.23
version: 0.2.24
repository: https://onyx-dot-app.github.io/vespa-helm-charts
condition: vespa.enabled
- name: nginx


@@ -123,15 +123,15 @@ function ActionForm({
<button
type="button"
className="
absolute
bottom-4
absolute
bottom-4
right-4
border-border
border
bg-background
rounded
py-1
px-3
py-1
px-3
text-sm
hover:bg-accent-background
"
@@ -162,7 +162,7 @@ function ActionForm({
/>
<div className="mt-4 text-sm bg-blue-50 text-blue-700 dark:text-blue-300 dark:bg-blue-900 p-4 rounded-md border border-blue-200 dark:border-blue-800">
<Link
href="https://docs.onyx.app/tools/custom"
href="https://docs.onyx.app/actions/custom#custom-actions"
className="text-link hover:underline flex items-center"
target="_blank"
rel="noopener noreferrer"

Some files were not shown because too many files have changed in this diff.