Compare commits


21 Commits

Author SHA1 Message Date
Evan Lohn
4d768e03d7 add fields to error temporarily 2025-08-08 11:37:15 -07:00
Evan Lohn
b89643f28c better logs and new attempt 2025-08-08 09:58:48 -07:00
Evan Lohn
555630070b more sf logs 2025-08-07 19:18:37 -07:00
Evan Lohn
1d16c96009 fix: sf connector docs 2025-08-07 19:18:37 -07:00
Evan Lohn
297720c132 refactor: file processing (#5136)
* file processing refactor

* mypy

* CW comments

* address CW
2025-08-08 00:34:35 +00:00
Evan Lohn
bd4bd00cef feat: office parsing markitdown (#5115)
* switch to markitdown untested

* passing tests

* reset file

* dotenv version

* docs

* add test file

* add doc

* fix integration test
2025-08-07 23:26:02 +00:00
Chris Weaver
07c482f727 Make starter messages visible on smaller screens (#5170) 2025-08-07 16:49:18 -07:00
Wenxi
cf193dee29 feat: support gpt5 models (#5169)
* support gpt5 models

* gpt5mini visible
2025-08-07 12:35:46 -07:00
Evan Lohn
1b47fa2700 fix: remove erroneous error case and add valid error (#5163)
* fix: remove erroneous error case and add valid error

* also address docfetching-docprocessing limbo
2025-08-07 18:17:00 +00:00
Wenxi Onyx
e1a305d18a mask llm api key from logs 2025-08-07 00:01:29 -07:00
Evan Lohn
e2233d22c9 feat: salesforce custom query (#5158)
* WIP merged approach untested

* tested custom configs

* JT comments

* fix unit test

* CW comments

* fix unit test
2025-08-07 02:37:23 +00:00
Justin Tahara
20d1175312 feat(infra): Bump Vespa Helm Version (#5161)
* feat(infra): Bump Vespa Helm Version

* Adding the Chart.lock file
2025-08-06 19:06:18 -07:00
justin-tahara
7117774287 Revert that change. Let's do this properly 2025-08-06 18:54:21 -07:00
justin-tahara
77f2660bb2 feat(infra): Update Vespa Helm Chart Version 2025-08-06 18:53:02 -07:00
Wenxi
1b2f4f3b87 fix: slash command slackbot to respond in private msg (#5151)
* fix slash command slackbot to respond in private msg

* rename confusing variable. fix slash message response in DMs
2025-08-05 19:03:38 -07:00
Evan Lohn
d85b55a9d2 no more scheduled stalling (#5154) 2025-08-05 20:17:44 +00:00
Justin Tahara
e2bae5a2d9 fix(infra): Adding helm directory (#5156)
* feat(infra): Adding helm directory

* one more fix
2025-08-05 14:11:57 -07:00
Justin Tahara
cc9c76c4fb feat(infra): Release Charts on Github Pages (#5155) 2025-08-05 14:03:28 -07:00
Chris Weaver
258e08abcd feat: add customization via env vars for curator role (#5150)
* Add customization via env vars for curator role

* Simplify

* Simplify more

* Address comments
2025-08-05 09:58:36 -07:00
Evan Lohn
67047e42a7 fix: preserve error traces (#5152) 2025-08-05 09:44:55 -07:00
SubashMohan
146628e734 fix unsupported character error in minio migration (#5145)
* fix unsupported character error in minio migration

* slash fix
2025-08-04 12:42:07 -07:00
56 changed files with 1074 additions and 583 deletions

View File

@@ -36,5 +36,7 @@ jobs:
      - name: Run chart-releaser
        uses: helm/chart-releaser-action@v1.7.0
        with:
          charts_dir: deployment/helm/charts
        env:
          CR_TOKEN: "${{ secrets.GITHUB_TOKEN }}"

View File

@@ -206,7 +206,7 @@ def _handle_standard_answers(
restate_question_blocks = get_restate_blocks(
msg=query_msg.message,
is_bot_msg=message_info.is_bot_msg,
is_slash_command=message_info.is_slash_command,
)
answer_blocks = build_standard_answer_blocks(

View File

@@ -67,7 +67,7 @@ def generate_chat_messages_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)
@@ -99,7 +99,7 @@ def generate_user_report(
file_id = file_store.save_file(
content=temp_file,
display_name=file_name,
file_origin=FileOrigin.OTHER,
file_origin=FileOrigin.GENERATED_REPORT,
file_type="text/csv",
)

View File

@@ -231,10 +231,7 @@ class DynamicTenantScheduler(PersistentScheduler):
True if equivalent, False if not."""
current_tasks = set(name for name, _ in schedule1)
new_tasks = set(schedule2.keys())
if current_tasks != new_tasks:
return False
return True
return current_tasks == new_tasks
@beat_init.connect

View File

@@ -32,7 +32,6 @@ from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_document_set import RedisDocumentSet
@@ -161,7 +160,6 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
RedisUserGroup.reset_all(r)
RedisConnectorDelete.reset_all(r)
RedisConnectorPrune.reset_all(r)
RedisConnectorIndex.reset_all(r)
RedisConnectorStop.reset_all(r)
RedisConnectorPermissionSync.reset_all(r)
RedisConnectorExternalGroupSync.reset_all(r)

View File

@@ -1,3 +1,5 @@
from collections.abc import Generator
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from pathlib import Path
@@ -8,10 +10,12 @@ import httpx
from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.connectors.connector_runner import batched_doc_ids
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import BaseConnector
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SlimConnector
@@ -22,12 +26,14 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
PRUNING_CHECKPOINTED_BATCH_SIZE = 32
def document_batch_to_ids(
doc_batch: list[Document],
) -> set[str]:
return {doc.id for doc in doc_batch}
doc_batch: Iterator[list[Document]],
) -> Generator[set[str], None, None]:
for doc_list in doc_batch:
yield {doc.id for doc in doc_list}
def extract_ids_from_runnable_connector(
@@ -46,33 +52,50 @@ def extract_ids_from_runnable_connector(
for metadata_batch in runnable_connector.retrieve_all_slim_documents():
all_connector_doc_ids.update({doc.id for doc in metadata_batch})
doc_batch_generator = None
doc_batch_id_generator = None
if isinstance(runnable_connector, LoadConnector):
doc_batch_generator = runnable_connector.load_from_state()
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.load_from_state()
)
elif isinstance(runnable_connector, PollConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
doc_batch_generator = runnable_connector.poll_source(start=start, end=end)
doc_batch_id_generator = document_batch_to_ids(
runnable_connector.poll_source(start=start, end=end)
)
elif isinstance(runnable_connector, CheckpointedConnector):
start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp()
end = datetime.now(timezone.utc).timestamp()
checkpoint = runnable_connector.build_dummy_checkpoint()
checkpoint_generator = runnable_connector.load_from_checkpoint(
start=start, end=end, checkpoint=checkpoint
)
doc_batch_id_generator = batched_doc_ids(
checkpoint_generator, batch_size=PRUNING_CHECKPOINTED_BATCH_SIZE
)
else:
raise RuntimeError("Pruning job could not find a valid runnable_connector.")
doc_batch_processing_func = document_batch_to_ids
# this function is called per batch for rate limiting
def doc_batch_processing_func(doc_batch_ids: set[str]) -> set[str]:
return doc_batch_ids
if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE:
doc_batch_processing_func = rate_limit_builder(
max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60
)(document_batch_to_ids)
for doc_batch in doc_batch_generator:
)(lambda x: x)
for doc_batch_ids in doc_batch_id_generator:
if callback:
if callback.should_stop():
raise RuntimeError(
"extract_ids_from_runnable_connector: Stop signal detected"
)
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch))
all_connector_doc_ids.update(doc_batch_processing_func(doc_batch_ids))
if callback:
callback.progress("extract_ids_from_runnable_connector", len(doc_batch))
callback.progress("extract_ids_from_runnable_connector", len(doc_batch_ids))
return all_connector_doc_ids
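
Note on the change above: pruning now streams sets of document IDs, and when MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE is set the rate limiter wraps a pass-through (lambda x: x) so the iteration itself is throttled rather than the ID extraction. A minimal sketch of that pattern; the real rate_limit_builder in onyx.connectors.cross_connector_utils may differ:

import time

MAX_BATCHES_PER_MINUTE = 10  # stands in for MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE

def rate_limit_builder(max_calls: int, period: int):
    """Sketch of a decorator factory: allow at most max_calls calls per period seconds."""
    def decorator(func):
        call_times: list[float] = []
        def wrapped(*args, **kwargs):
            while True:
                # drop timestamps outside the window, then block while at the cap
                cutoff = time.monotonic() - period
                call_times[:] = [t for t in call_times if t > cutoff]
                if len(call_times) < max_calls:
                    break
                time.sleep(0.05)
            call_times.append(time.monotonic())
            return func(*args, **kwargs)
        return wrapped
    return decorator

# rate limiting applied to an identity function, one call per ID batch
throttled_identity = rate_limit_builder(max_calls=MAX_BATCHES_PER_MINUTE, period=60)(lambda x: x)

all_ids: set[str] = set()
for batch_ids in ({"doc-1", "doc-2"}, {"doc-3"}):
    all_ids.update(throttled_identity(batch_ids))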

View File

@@ -193,12 +193,7 @@ def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | N
task_logger.info(
"Timed out waiting for tasks blocking deletion. Resetting blocking fences."
)
search_settings_list = get_all_search_settings(db_session)
for search_settings in search_settings_list:
redis_connector_index = redis_connector.new_index(
search_settings.id
)
redis_connector_index.reset()
redis_connector.prune.reset()
redis_connector.permissions.reset()
redis_connector.external_group_sync.reset()

View File

@@ -2,7 +2,6 @@ import multiprocessing
import os
import time
import traceback
from http import HTTPStatus
from time import sleep
import sentry_sdk
@@ -22,7 +21,7 @@ from onyx.background.celery.tasks.models import SimpleJobResult
from onyx.background.indexing.job_client import SimpleJob
from onyx.background.indexing.job_client import SimpleJobClient
from onyx.background.indexing.job_client import SimpleJobException
from onyx.background.indexing.run_docfetching import run_indexing_entrypoint
from onyx.background.indexing.run_docfetching import run_docfetching_entrypoint
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.connectors.exceptions import ConnectorValidationError
@@ -34,7 +33,6 @@ from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import SENTRY_DSN
@@ -156,7 +154,6 @@ def _docfetching_task(
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector.new_index(search_settings_id)
# TODO: remove all fences, cause all signals to be set in postgres
if redis_connector.delete.fenced:
@@ -214,7 +211,7 @@ def _docfetching_task(
)
# This is where the heavy/real work happens
run_indexing_entrypoint(
run_docfetching_entrypoint(
app,
index_attempt_id,
tenant_id,
@@ -261,7 +258,7 @@ def _docfetching_task(
def process_job_result(
job: SimpleJob,
connector_source: str | None,
redis_connector_index: RedisConnectorIndex,
index_attempt_id: int,
log_builder: ConnectorIndexingLogBuilder,
) -> SimpleJobResult:
result = SimpleJobResult()
@@ -278,13 +275,11 @@ def process_job_result(
# In EKS, there is an edge case where successful tasks return exit
# code 1 in the cloud due to the set_spawn_method not sticking.
# We've since worked around this, but the following is a safe way to
# work around this issue. Basically, we ignore the job error state
# if the completion signal is OK.
status_int = redis_connector_index.get_completion()
if status_int:
status_enum = HTTPStatus(status_int)
if status_enum == HTTPStatus.OK:
# Workaround: check that the total number of batches is set, since this only
# happens when docfetching completed successfully
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt and index_attempt.total_batches is not None:
ignore_exitcode = True
if ignore_exitcode:
@@ -458,9 +453,6 @@ def docfetching_proxy_task(
)
)
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
@@ -487,7 +479,7 @@ def docfetching_proxy_task(
if job.done():
try:
result = process_job_result(
job, result.connector_source, redis_connector_index, log_builder
job, result.connector_source, index_attempt_id, log_builder
)
except Exception:
task_logger.exception(
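
Note on the workaround above: the old code consulted a Redis completion signal (HTTPStatus.OK); the new code treats a non-zero exit code as success when index_attempt.total_batches is set in Postgres, since that column is only written when docfetching finished cleanly. A self-contained sketch of that decision, with a hypothetical stand-in for the model:

from dataclasses import dataclass

@dataclass
class IndexAttemptStub:
    # hypothetical stand-in for the relevant fields of onyx.db.models.IndexAttempt
    id: int
    total_batches: int | None

def job_really_succeeded(attempt: IndexAttemptStub | None, exit_code: int) -> bool:
    # EKS edge case from the diff: a successful spawned task can still exit 1,
    # so a non-zero exit is forgiven when total_batches was written
    if exit_code == 0:
        return True
    return attempt is not None and attempt.total_batches is not None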

View File

@@ -4,7 +4,6 @@ from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from http import HTTPStatus
from typing import Any
from celery import shared_task
@@ -16,6 +15,8 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat
@@ -66,6 +67,7 @@ from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.index_attempt import mark_attempt_partially_succeeded
from onyx.db.index_attempt import mark_attempt_succeeded
from onyx.db.indexing_coordination import CoordinationStatus
from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS
from onyx.db.indexing_coordination import IndexingCoordination
from onyx.db.models import IndexAttempt
from onyx.db.search_settings import get_active_search_settings_list
@@ -102,6 +104,7 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
USER_FILE_INDEXING_LIMIT = 100
DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER = 4
def _get_fence_validation_block_expiration() -> int:
@@ -257,7 +260,7 @@ class ConnectorIndexingLogBuilder:
def monitor_indexing_attempt_progress(
attempt: IndexAttempt, tenant_id: str, db_session: Session
attempt: IndexAttempt, tenant_id: str, db_session: Session, task: Task
) -> None:
"""
TODO: rewrite this docstring
@@ -316,7 +319,9 @@ def monitor_indexing_attempt_progress(
# Check task completion using Celery
try:
check_indexing_completion(attempt.id, coordination_status, storage, tenant_id)
check_indexing_completion(
attempt.id, coordination_status, storage, tenant_id, task
)
except Exception as e:
logger.exception(
f"Failed to monitor document processing completion: "
@@ -350,6 +355,7 @@ def check_indexing_completion(
coordination_status: CoordinationStatus,
storage: DocumentBatchStorage,
tenant_id: str,
task: Task,
) -> None:
logger.info(
@@ -376,20 +382,78 @@ def check_indexing_completion(
# Update progress tracking and check for stalls
with get_session_with_current_tenant() as db_session:
# Update progress tracking
stalled_timeout_hours = INDEXING_PROGRESS_TIMEOUT_HOURS
# Index attempts that are waiting between docfetching and
# docprocessing get a generous stalling timeout
if batches_total is not None and batches_processed == 0:
stalled_timeout_hours = (
stalled_timeout_hours * DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER
)
timed_out = not IndexingCoordination.update_progress_tracking(
db_session, index_attempt_id, batches_processed
db_session,
index_attempt_id,
batches_processed,
timeout_hours=stalled_timeout_hours,
)
# Check for stalls (3-6 hour timeout)
if timed_out:
logger.error(
f"Indexing attempt {index_attempt_id} has been indexing for 3-6 hours without progress. "
f"Marking it as failed."
)
mark_attempt_failed(
index_attempt_id, db_session, failure_reason="Stalled indexing"
)
# Check for stalls (3-6 hour timeout). Only applies to in-progress attempts.
attempt = get_index_attempt(db_session, index_attempt_id)
if attempt and timed_out:
if attempt.status == IndexingStatus.IN_PROGRESS:
logger.error(
f"Indexing attempt {index_attempt_id} has been indexing for "
f"{stalled_timeout_hours//2}-{stalled_timeout_hours} hours without progress. "
f"Marking it as failed."
)
mark_attempt_failed(
index_attempt_id, db_session, failure_reason="Stalled indexing"
)
elif (
attempt.status == IndexingStatus.NOT_STARTED and attempt.celery_task_id
):
# Check if the task exists in the celery queue
# This handles the case where Redis dies after task creation but before task execution
redis_celery = task.app.broker_connection().channel().client # type: ignore
task_exists = celery_find_task(
attempt.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
redis_celery,
)
unacked_task_ids = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, redis_celery
)
if not task_exists and attempt.celery_task_id not in unacked_task_ids:
# there is a race condition where the docfetching task has been taken off
# the queues (i.e. started) but the indexing attempt still has a status of
# Not Started because the switch to in progress takes like 0.1 seconds.
# sleep a bit and confirm that the attempt is still not in progress.
time.sleep(1)
attempt = get_index_attempt(db_session, index_attempt_id)
if attempt and attempt.status == IndexingStatus.NOT_STARTED:
logger.error(
f"Task {attempt.celery_task_id} attached to indexing attempt "
f"{index_attempt_id} does not exist in the queue. "
f"Marking indexing attempt as failed."
)
mark_attempt_failed(
index_attempt_id,
db_session,
failure_reason="Task not in queue",
)
else:
logger.info(
f"Indexing attempt {index_attempt_id} is {attempt.status}. 3-6 hours without heartbeat "
"but task is in the queue. Likely underprovisioned docfetching worker."
)
# Update last progress time so we won't time out again for another 3 hours
IndexingCoordination.update_progress_tracking(
db_session,
index_attempt_id,
batches_processed,
force_update_progress=True,
)
# check again on the next check_for_indexing task
# TODO: on the cloud this is currently 25 minutes at most, which
@@ -449,15 +513,6 @@ def check_indexing_completion(
db_session=db_session,
)
# TODO: make it so we don't need this (might already be true)
redis_connector = RedisConnector(
tenant_id, attempt.connector_credential_pair_id
)
redis_connector_index = redis_connector.new_index(
attempt.search_settings_id
)
redis_connector_index.set_generator_complete(HTTPStatus.OK.value)
# Clean up FileStore storage (still needed for document batches during transition)
try:
logger.info(f"Cleaning up storage after indexing completion: {storage}")
@@ -811,7 +866,9 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
for attempt in active_attempts:
try:
monitor_indexing_attempt_progress(attempt, tenant_id, db_session)
monitor_indexing_attempt_progress(
attempt, tenant_id, db_session, self
)
except Exception:
task_logger.exception(f"Error monitoring attempt {attempt.id}")
@@ -1085,12 +1142,8 @@ def _docprocessing_task(
f"Index attempt {index_attempt_id} is not running, status {index_attempt.status}"
)
redis_connector_index = redis_connector.new_index(
index_attempt.search_settings.id
)
cross_batch_db_lock: RedisLock = r.lock(
redis_connector_index.db_lock_key,
redis_connector.db_lock_key(index_attempt.search_settings.id),
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
thread_local=False,
)
@@ -1230,17 +1283,6 @@ def _docprocessing_task(
f"attempt={index_attempt_id} "
)
# on failure, signal completion with an error to unblock the watchdog
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(db_session, index_attempt_id)
if index_attempt and index_attempt.search_settings:
redis_connector_index = redis_connector.new_index(
index_attempt.search_settings.id
)
redis_connector_index.set_generator_complete(
HTTPStatus.INTERNAL_SERVER_ERROR.value
)
raise
finally:
if per_batch_lock and per_batch_lock.owned():
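
Note on the stall-detection change above: attempts sitting between docfetching and docprocessing get a widened timeout before being marked stalled. A minimal sketch; the base constant is imported from onyx.db.indexing_coordination and its value of 6 is assumed here from the "3-6 hours" log text:

INDEXING_PROGRESS_TIMEOUT_HOURS = 6  # assumed default
DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER = 4

def stall_timeout_hours(batches_total: int | None, batches_processed: int) -> int:
    # attempts with all batches fetched but none yet processed get a
    # 4x more generous stall timeout, as in the diff above
    timeout = INDEXING_PROGRESS_TIMEOUT_HOURS
    if batches_total is not None and batches_processed == 0:
        timeout *= DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER
    return timeout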

View File

@@ -47,7 +47,6 @@ from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.search_settings import get_current_search_settings
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.tag import delete_orphan_tags__no_commit
@@ -519,9 +518,6 @@ def connector_pruning_generator_task(
cc_pair.credential,
)
search_settings = get_current_search_settings(db_session)
redis_connector.new_index(search_settings.id)
callback = PruneCallback(
0,
redis_connector,

View File

@@ -226,8 +226,12 @@ def _check_connector_and_attempt_status(
raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled")
if index_attempt_loop.status != IndexingStatus.IN_PROGRESS:
error_str = ""
if index_attempt_loop.error_msg:
error_str = f" Original error: {index_attempt_loop.error_msg}"
raise RuntimeError(
f"Index Attempt is not running, status is {index_attempt_loop.status}"
f"Index Attempt is not running, status is {index_attempt_loop.status}.{error_str}"
)
if index_attempt_loop.celery_task_id is None:
@@ -832,7 +836,7 @@ def _run_indexing(
)
def run_indexing_entrypoint(
def run_docfetching_entrypoint(
app: Celery,
index_attempt_id: int,
tenant_id: str,
@@ -1350,6 +1354,9 @@ def reissue_old_batches(
)
path_info = batch_storage.extract_path_info(batch_id)
if path_info is None:
logger.warning(
f"Could not extract path info from batch {batch_id}, skipping"
)
continue
if path_info.cc_pair_id != cc_pair_id:
raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}")

View File

@@ -359,6 +359,12 @@ POLL_CONNECTOR_OFFSET = 30 # Minutes overlap between poll windows
# only very select connectors are enabled and admins cannot add other connector types
ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""
# If set to true, curators can only access and edit assistants that they created
CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS = (
os.environ.get("CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS", "").lower()
== "true"
)
# Some calls to get information on expert users are quite costly especially with rate limiting
# Since experts are not used in the actual user experience, currently it is turned off
# for some connectors

View File

@@ -25,6 +25,28 @@ TimeRange = tuple[datetime, datetime]
CT = TypeVar("CT", bound=ConnectorCheckpoint)
def batched_doc_ids(
checkpoint_connector_generator: CheckpointOutput[CT],
batch_size: int,
) -> Generator[set[str], None, None]:
batch: set[str] = set()
for document, failure, next_checkpoint in CheckpointOutputWrapper[CT]()(
checkpoint_connector_generator
):
if document is not None:
batch.add(document.id)
elif (
failure and failure.failed_document and failure.failed_document.document_id
):
batch.add(failure.failed_document.document_id)
if len(batch) >= batch_size:
yield batch
batch = set()
if len(batch) > 0:
yield batch
class CheckpointOutputWrapper(Generic[CT]):
"""
Wraps a CheckpointOutput generator to give things back in a more digestible format,

View File

@@ -1,5 +1,6 @@
import csv
import gc
import json
import os
import sys
import tempfile
@@ -28,8 +29,12 @@ from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_sqlite_db_path
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -38,27 +43,27 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
_DEFAULT_PARENT_OBJECT_TYPES = ["Account"]
_DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE]
_DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = {
"Opportunity": {
"Account": "account",
ACCOUNT_OBJECT_TYPE: "account",
"FiscalQuarter": "fiscal_quarter",
"FiscalYear": "fiscal_year",
"IsClosed": "is_closed",
"Name": "name",
NAME_FIELD: "name",
"StageName": "stage_name",
"Type": "type",
"Amount": "amount",
"CloseDate": "close_date",
"Probability": "probability",
"CreatedDate": "created_date",
"LastModifiedDate": "last_modified_date",
MODIFIED_FIELD: "last_modified_date",
},
"Contact": {
"Account": "account",
ACCOUNT_OBJECT_TYPE: "account",
"CreatedDate": "created_date",
"LastModifiedDate": "last_modified_date",
MODIFIED_FIELD: "last_modified_date",
},
}
@@ -74,19 +79,77 @@ class SalesforceConnectorContext:
parent_to_child_types: dict[str, set[str]] = {} # map from parent to child types
child_to_parent_types: dict[str, set[str]] = {} # map from child to parent types
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {}
type_to_queryable_fields: dict[str, list[str]] = {}
type_to_queryable_fields: dict[str, set[str]] = {}
prefix_to_type: dict[str, str] = {} # infer the object type of an id immediately
parent_to_child_relationships: dict[str, set[str]] = (
{}
) # map from parent to child relationships
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = (
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = (
{}
) # map from relationship to queryable fields
parent_child_names_to_relationships: dict[str, str] = {}
def _extract_fields_and_associations_from_config(
config: dict[str, Any], object_type: str
) -> tuple[list[str] | None, dict[str, list[str]]]:
"""
Extract fields and associations for a specific object type from custom config.
Returns:
tuple of (fields_list, associations_dict)
- fields_list: List of fields to query, or None if not specified (use all)
- associations_dict: Dict mapping association names to their config
"""
if object_type not in config:
return None, {}
obj_config = config[object_type]
fields = obj_config.get("fields")
associations = obj_config.get("associations", {})
return fields, associations
def _validate_custom_query_config(config: dict[str, Any]) -> None:
"""
Validate the structure of the custom query configuration.
"""
for object_type, obj_config in config.items():
if not isinstance(obj_config, dict):
raise ValueError(
f"top level object {object_type} must be mapped to a dictionary"
)
# Check if fields is a list when present
if "fields" in obj_config:
if not isinstance(obj_config["fields"], list):
raise ValueError("if fields key exists, value must be a list")
for v in obj_config["fields"]:
if not isinstance(v, str):
raise ValueError(f"if fields list value {v} is not a string")
# Check if associations is a dict when present
if "associations" in obj_config:
if not isinstance(obj_config["associations"], dict):
raise ValueError(
"if associations key exists, value must be a dictionary"
)
for assoc_name, assoc_fields in obj_config["associations"].items():
if not isinstance(assoc_fields, list):
raise ValueError(
f"associations list value {assoc_fields} for key {assoc_name} is not a list"
)
for v in assoc_fields:
if not isinstance(v, str):
raise ValueError(
f"if associations list value {v} is not a string"
)
class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
"""Approach outline
@@ -134,14 +197,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
batch_size: int = INDEX_BATCH_SIZE,
requested_objects: list[str] = [],
custom_query_config: str | None = None,
) -> None:
self.batch_size = batch_size
self._sf_client: OnyxSalesforce | None = None
self.parent_object_list = (
[obj.capitalize() for obj in requested_objects]
if requested_objects
else _DEFAULT_PARENT_OBJECT_TYPES
)
# Validate and store custom query config
if custom_query_config:
config_json = json.loads(custom_query_config)
self.custom_query_config: dict[str, Any] | None = config_json
# If custom query config is provided, use the object types from it
self.parent_object_list = list(config_json.keys())
else:
self.custom_query_config = None
# Use the traditional requested_objects approach
self.parent_object_list = (
[obj.strip().capitalize() for obj in requested_objects]
if requested_objects
else _DEFAULT_PARENT_OBJECT_TYPES
)
def load_credentials(
self,
@@ -187,7 +261,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
@staticmethod
def _download_object_csvs(
all_types_to_filter: dict[str, bool],
queryable_fields_by_type: dict[str, list[str]],
queryable_fields_by_type: dict[str, set[str]],
directory: str,
sf_client: OnyxSalesforce,
start: SecondsSinceUnixEpoch | None = None,
@@ -325,9 +399,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# all_types.update(child_types.keys())
# # Always want to make sure user is grabbed for permissioning purposes
# all_types.add("User")
# all_types.add(USER_OBJECT_TYPE)
# # Always want to make sure account is grabbed for reference purposes
# all_types.add("Account")
# all_types.add(ACCOUNT_OBJECT_TYPE)
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
@@ -351,7 +425,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# all_types.update(child_types)
# # Always want to make sure user is grabbed for permissioning purposes
# all_types.add("User")
# all_types.add(USER_OBJECT_TYPE)
# logger.info(f"All object types: num={len(all_types)} list={all_types}")
@@ -364,7 +438,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
) -> GenerateDocumentsOutput:
type_to_processed: dict[str, int] = {}
logger.info("_fetch_from_salesforce starting.")
logger.info("_fetch_from_salesforce starting (full sync).")
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
@@ -548,7 +622,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
) -> GenerateDocumentsOutput:
type_to_processed: dict[str, int] = {}
logger.info("_fetch_from_salesforce starting.")
logger.info("_fetch_from_salesforce starting (delta sync).")
if not self._sf_client:
raise RuntimeError("self._sf_client is None!")
@@ -677,7 +751,9 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
try:
last_modified_by_id = record["LastModifiedById"]
user_record = self.sf_client.query_object(
"User", last_modified_by_id, ctx.type_to_queryable_fields
USER_OBJECT_TYPE,
last_modified_by_id,
ctx.type_to_queryable_fields,
)
if user_record:
primary_owner = BasicExpertInfo.from_dict(user_record)
@@ -792,7 +868,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = (
{}
) # for a given object, the fields reference parent objects
type_to_queryable_fields: dict[str, list[str]] = {}
type_to_queryable_fields: dict[str, set[str]] = {}
prefix_to_type: dict[str, str] = {}
parent_to_child_relationships: dict[str, set[str]] = (
@@ -802,15 +878,13 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# relationship keys are formatted as "parent__relationship"
# we have to do this because relationship names are not unique!
# values are a dict of relationship names to a list of queryable fields
parent_to_relationship_queryable_fields: dict[str, dict[str, list[str]]] = {}
parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = {}
parent_child_names_to_relationships: dict[str, str] = {}
full_sync = False
if start is None and end is None:
full_sync = True
full_sync = start is None and end is None
# Step 1 - make a list of all the types to download (parent + direct child + "User")
# Step 1 - make a list of all the types to download (parent + direct child + USER_OBJECT_TYPE)
# prefixes = {}
global_description = sf_client.describe()
@@ -831,16 +905,62 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
for parent_type in parent_types:
# parent_onyx_sf_type = OnyxSalesforceType(parent_type, sf_client)
type_to_queryable_fields[parent_type] = (
sf_client.get_queryable_fields_by_type(parent_type)
)
child_types_working = sf_client.get_children_of_sf_type(parent_type)
logger.debug(
f"Found {len(child_types_working)} child types for {parent_type}"
)
custom_fields: list[str] | None = []
associations_config: dict[str, list[str]] | None = None
# parent_to_child_relationships[parent_type] = child_types_working
# Set queryable fields for parent type
if self.custom_query_config:
custom_fields, associations_config = (
_extract_fields_and_associations_from_config(
self.custom_query_config, parent_type
)
)
custom_fields = custom_fields or []
# Get custom fields for parent type
field_set = set(custom_fields)
# these are expected and used during doc conversion
field_set.add(NAME_FIELD)
field_set.add(MODIFIED_FIELD)
# Use only the specified fields
type_to_queryable_fields[parent_type] = field_set
logger.info(f"Using custom fields for {parent_type}: {field_set}")
else:
# Use all queryable fields
type_to_queryable_fields[parent_type] = (
sf_client.get_queryable_fields_by_type(parent_type)
)
logger.info(f"Using all fields for {parent_type}")
child_types_all = sf_client.get_children_of_sf_type(parent_type)
logger.debug(f"Found {len(child_types_all)} child types for {parent_type}")
logger.debug(f"child types: {child_types_all}")
child_types_working = child_types_all.copy()
if associations_config is not None:
child_types_working = {
k: v for k, v in child_types_all.items() if v in associations_config
}
any_not_found = False
for k in associations_config:
if k not in child_types_working:
any_not_found = True
logger.warning(f"Association {k} not found in {parent_type}")
if any_not_found:
queryable_fields = sf_client.get_queryable_fields_by_type(
parent_type
)
raise RuntimeError(
f"Associations {associations_config} not found in {parent_type} "
f"with child objects {child_types_all}"
f" and fields {queryable_fields}"
)
parent_to_child_relationships[parent_type] = set()
parent_to_child_types[parent_type] = set()
parent_to_relationship_queryable_fields[parent_type] = {}
for child_type, child_relationship in child_types_working.items():
child_type = cast(str, child_type)
@@ -848,8 +968,6 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
# map parent name to child name
if parent_type not in parent_to_child_types:
parent_to_child_types[parent_type] = set()
parent_to_child_types[parent_type].add(child_type)
# reverse map child name to parent name
@@ -858,19 +976,25 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
child_to_parent_types[child_type].add(parent_type)
# map parent name to child relationship
if parent_type not in parent_to_child_relationships:
parent_to_child_relationships[parent_type] = set()
parent_to_child_relationships[parent_type].add(child_relationship)
# map relationship to queryable fields of the target table
queryable_fields = sf_client.get_queryable_fields_by_type(child_type)
if config_fields := (
associations_config and associations_config.get(child_type)
):
field_set = set(config_fields)
# these are expected and used during doc conversion
field_set.add(NAME_FIELD)
field_set.add(MODIFIED_FIELD)
queryable_fields = field_set
else:
queryable_fields = sf_client.get_queryable_fields_by_type(
child_type
)
if child_relationship in parent_to_relationship_queryable_fields:
raise RuntimeError(f"{child_relationship=} already exists")
if parent_type not in parent_to_relationship_queryable_fields:
parent_to_relationship_queryable_fields[parent_type] = {}
parent_to_relationship_queryable_fields[parent_type][
child_relationship
] = queryable_fields
@@ -894,14 +1018,22 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
all_types.update(child_types)
# NOTE(rkuo): should this be an implicit parent type?
all_types.add("User") # Always add User for permissioning purposes
all_types.add("Account") # Always add Account for reference purposes
all_types.add(USER_OBJECT_TYPE) # Always add User for permissioning purposes
all_types.add(ACCOUNT_OBJECT_TYPE) # Always add Account for reference purposes
logger.info(f"All object types: num={len(all_types)} list={all_types}")
# Ensure User and Account have queryable fields if they weren't already processed
essential_types = [USER_OBJECT_TYPE, ACCOUNT_OBJECT_TYPE]
for essential_type in essential_types:
if essential_type not in type_to_queryable_fields:
type_to_queryable_fields[essential_type] = (
sf_client.get_queryable_fields_by_type(essential_type)
)
# 1.1 - Detect all fields in child types which reference a parent type.
# build dicts to detect relationships between parent and child
for child_type in child_types:
for child_type in child_types.union(essential_types):
# onyx_sf_type = OnyxSalesforceType(child_type, sf_client)
parent_reference_fields = sf_client.get_parent_reference_fields(
child_type, parent_types
@@ -1003,6 +1135,32 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
yield doc_metadata_list
def validate_connector_settings(self) -> None:
"""
Validate that the Salesforce credentials and connector settings are correct.
Specifically checks that we can make an authenticated request to Salesforce.
"""
try:
# Attempt to fetch a small batch of objects (arbitrary endpoint) to verify credentials
self.sf_client.describe()
except Exception as e:
raise ConnectorMissingCredentialError(
"Failed to validate Salesforce credentials. Please check your"
f"credentials and try again. Error: {e}"
)
if self.custom_query_config:
try:
_validate_custom_query_config(self.custom_query_config)
except Exception as e:
raise ConnectorMissingCredentialError(
"Failed to validate Salesforce custom query config. Please check your"
f"config and try again. Error: {e}"
)
logger.info("Salesforce credentials validated successfully.")
# @override
# def load_from_checkpoint(
# self,
@@ -1032,7 +1190,7 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
if __name__ == "__main__":
connector = SalesforceConnector(requested_objects=["Account"])
connector = SalesforceConnector(requested_objects=[ACCOUNT_OBJECT_TYPE])
connector.load_credentials(
{
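
Note on the custom query config above: per _validate_custom_query_config, each top-level key is a parent object type whose value may carry a "fields" list of strings and an "associations" dict mapping association/child names to field lists (the exact key semantics for associations are ambiguous in the source). A hypothetical config that would pass validation:

import json

# hypothetical example; field and association names are illustrative only
custom_query_config = json.dumps(
    {
        "Opportunity": {
            "fields": ["StageName", "Amount", "CloseDate"],
            "associations": {"Contact": ["Email", "Title"]},
        }
    }
)
# this JSON string would be passed as SalesforceConnector(custom_query_config=...)

Note that Name and LastModifiedDate (NAME_FIELD / MODIFIED_FIELD) are added to every field set automatically, since doc conversion depends on them.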

View File

@@ -10,6 +10,8 @@ from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.utils.logger import setup_logger
@@ -140,7 +142,7 @@ def _extract_primary_owner(
first_name=user_data.get("FirstName"),
last_name=user_data.get("LastName"),
email=user_data.get("Email"),
display_name=user_data.get("Name"),
display_name=user_data.get(NAME_FIELD),
)
# Check if all fields are None
@@ -166,8 +168,8 @@ def convert_sf_query_result_to_doc(
"""Generates a yieldable Document from query results"""
base_url = f"https://{sf_client.sf_instance}"
extracted_doc_updated_at = time_str_to_utc(record["LastModifiedDate"])
extracted_semantic_identifier = record.get("Name", "Unknown Object")
extracted_doc_updated_at = time_str_to_utc(record[MODIFIED_FIELD])
extracted_semantic_identifier = record.get(NAME_FIELD, "Unknown Object")
sections = [_extract_section(record, f"{base_url}/{record_id}")]
for child_record_key, child_record in child_records.items():
@@ -205,8 +207,8 @@ def convert_sf_object_to_doc(
salesforce_id = object_dict["Id"]
onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}"
base_url = f"https://{sf_instance}"
extracted_doc_updated_at = time_str_to_utc(object_dict["LastModifiedDate"])
extracted_semantic_identifier = object_dict.get("Name", "Unknown Object")
extracted_doc_updated_at = time_str_to_utc(object_dict[MODIFIED_FIELD])
extracted_semantic_identifier = object_dict.get(NAME_FIELD, "Unknown Object")
sections = [_extract_section(sf_object.data, f"{base_url}/{sf_object.id}")]
for id in sf_db.get_child_ids(sf_object.id):

View File

@@ -60,7 +60,7 @@ class OnyxSalesforce(Salesforce):
return True
for suffix in SALESFORCE_BLACKLISTED_SUFFIXES:
if object_type_lower.endswith(prefix):
if object_type_lower.endswith(suffix):
return True
return False
@@ -112,7 +112,7 @@ class OnyxSalesforce(Salesforce):
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> str:
"""Returns a SOQL query given the object id, type and child relationships.
@@ -148,7 +148,7 @@ class OnyxSalesforce(Salesforce):
self,
object_type: str,
object_id: str,
type_to_queryable_fields: dict[str, list[str]],
type_to_queryable_fields: dict[str, set[str]],
) -> dict[str, Any] | None:
record: dict[str, Any] = {}
@@ -172,7 +172,7 @@ class OnyxSalesforce(Salesforce):
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> dict[str, dict[str, Any]]:
"""There's a limit on the number of subqueries we can put in a single query."""
child_records: dict[str, dict[str, Any]] = {}
@@ -264,10 +264,10 @@ class OnyxSalesforce(Salesforce):
time.sleep(3)
raise
def get_queryable_fields_by_type(self, name: str) -> list[str]:
def get_queryable_fields_by_type(self, name: str) -> set[str]:
object_description = self.describe_type(name)
if object_description is None:
return []
return set()
fields: list[dict[str, Any]] = object_description["fields"]
valid_fields: set[str] = set()
@@ -286,7 +286,7 @@ class OnyxSalesforce(Salesforce):
if field_name:
valid_fields.add(field_name)
return list(valid_fields - field_names_to_remove)
return valid_fields - field_names_to_remove
def get_children_of_sf_type(self, sf_type: str) -> dict[str, str]:
"""Returns a dict of child object names to relationship names.

View File

@@ -14,6 +14,7 @@ from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
rate_limit_builder,
)
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -54,12 +55,12 @@ def _build_created_date_time_filter_for_salesforce(
def _make_time_filter_for_sf_type(
queryable_fields: list[str],
queryable_fields: set[str],
start: SecondsSinceUnixEpoch,
end: SecondsSinceUnixEpoch,
) -> str | None:
if "LastModifiedDate" in queryable_fields:
if MODIFIED_FIELD in queryable_fields:
return _build_last_modified_time_filter_for_salesforce(start, end)
if "CreatedDate" in queryable_fields:
@@ -69,14 +70,14 @@ def _make_time_filter_for_sf_type(
def _make_time_filtered_query(
queryable_fields: list[str], sf_type: str, time_filter: str
queryable_fields: set[str], sf_type: str, time_filter: str
) -> str:
query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
return query
def get_object_by_id_query(
object_id: str, sf_type: str, queryable_fields: list[str]
object_id: str, sf_type: str, queryable_fields: set[str]
) -> str:
query = (
f"SELECT {', '.join(queryable_fields)} FROM {sf_type} WHERE Id = '{object_id}'"
@@ -193,7 +194,7 @@ def _bulk_retrieve_from_salesforce(
def fetch_all_csvs_in_parallel(
sf_client: Salesforce,
all_types_to_filter: dict[str, bool],
queryable_fields_by_type: dict[str, list[str]],
queryable_fields_by_type: dict[str, set[str]],
start: SecondsSinceUnixEpoch | None,
end: SecondsSinceUnixEpoch | None,
target_dir: str,
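
Note on the query builders above: since queryable fields are now a set, the SELECT list order varies run to run. A hedged illustration, assuming a filter string shape like the one below (_build_last_modified_time_filter_for_salesforce is not shown in the diff):

queryable_fields = {"Id", "Name", "LastModifiedDate"}  # set, so ordering is arbitrary
sf_type = "Account"
time_filter = " WHERE LastModifiedDate > 2025-01-01T00:00:00Z"  # assumed shape

query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}"
# -> e.g. "SELECT Name, Id, LastModifiedDate FROM Account WHERE LastModifiedDate > 2025-01-01T00:00:00Z"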

View File

@@ -8,11 +8,15 @@ from pathlib import Path
from typing import Any
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger
from shared_configs.utils import batch_list
logger = setup_logger()
@@ -567,7 +571,7 @@ class OnyxSalesforceSQLite:
uncommitted_rows = 0
# If we're updating User objects, update the email map
if object_type == "User":
if object_type == USER_OBJECT_TYPE:
OnyxSalesforceSQLite._update_user_email_map(cursor)
return updated_ids
@@ -619,7 +623,7 @@ class OnyxSalesforceSQLite:
with self._conn:
cursor = self._conn.cursor()
# Get the object data and account data
if object_type == "Account" or isChild:
if object_type == ACCOUNT_OBJECT_TYPE or isChild:
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?", (object_id,)
)
@@ -638,7 +642,7 @@ class OnyxSalesforceSQLite:
data = json.loads(result[0][0])
if object_type != "Account":
if object_type != ACCOUNT_OBJECT_TYPE:
# convert any account ids of the relationships back into data fields, with name
for row in result:
@@ -647,14 +651,14 @@ class OnyxSalesforceSQLite:
if len(row) < 3:
continue
if row[1] and row[2] and row[2] == "Account":
if row[1] and row[2] and row[2] == ACCOUNT_OBJECT_TYPE:
data["AccountId"] = row[1]
cursor.execute(
"SELECT data FROM salesforce_objects WHERE id = ?",
(row[1],),
)
account_data = json.loads(cursor.fetchone()[0])
data["Account"] = account_data.get("Name", "")
data[ACCOUNT_OBJECT_TYPE] = account_data.get(NAME_FIELD, "")
return SalesforceObject(id=object_id, type=object_type, data=data)

View File

@@ -2,6 +2,11 @@ import os
from dataclasses import dataclass
from typing import Any
NAME_FIELD = "Name"
MODIFIED_FIELD = "LastModifiedDate"
ACCOUNT_OBJECT_TYPE = "Account"
USER_OBJECT_TYPE = "User"
@dataclass
class SalesforceObject:

View File

@@ -267,6 +267,7 @@ class IndexingCoordination:
index_attempt_id: int,
current_batches_completed: int,
timeout_hours: int = INDEXING_PROGRESS_TIMEOUT_HOURS,
force_update_progress: bool = False,
) -> bool:
"""
Update progress tracking for stall detection.
@@ -281,7 +282,8 @@ class IndexingCoordination:
current_time = get_db_current_time(db_session)
# No progress - check if this is the first time tracking
if attempt.last_progress_time is None:
# or if the caller wants to simulate guaranteed progress
if attempt.last_progress_time is None or force_update_progress:
# First time tracking - initialize
attempt.last_progress_time = current_time
attempt.last_batches_completed_count = current_batches_completed

View File

@@ -15,6 +15,7 @@ from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
from onyx.configs.app_configs import DISABLE_AUTH
from onyx.configs.chat_configs import BING_API_KEY
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
@@ -96,6 +97,14 @@ def _add_user_filters(
where_clause = Persona.is_public == True # noqa: E712
return stmt.where(where_clause)
# If curator ownership restriction is enabled, curators can only access their own assistants
if CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS and user.role in [
UserRole.CURATOR,
UserRole.GLOBAL_CURATOR,
]:
where_clause = (Persona.user_id == user.id) | (Persona.user_id.is_(None))
return stmt.where(where_clause)
where_clause = User__UserGroup.user_id == user.id
if user.role == UserRole.CURATOR and get_editable:
where_clause &= User__UserGroup.is_curator == True # noqa: E712
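
Note on the curator restriction above: when the env flag is on, curators may only access personas they own or that have no owner. A minimal sketch of that predicate outside SQLAlchemy, with a simplified stand-in for UserRole:

from enum import Enum

class UserRole(str, Enum):  # simplified stand-in for onyx.auth.schemas.UserRole
    CURATOR = "curator"
    GLOBAL_CURATOR = "global_curator"
    ADMIN = "admin"

def curator_can_access(
    persona_owner_id: str | None,
    user_id: str,
    role: UserRole,
    restriction_enabled: bool,
) -> bool:
    # mirrors the where clause above: ownership (or no owner) is required
    # for curators only when the env flag is enabled
    if restriction_enabled and role in (UserRole.CURATOR, UserRole.GLOBAL_CURATOR):
        return persona_owner_id == user_id or persona_owner_id is None
    return True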

View File

@@ -17,11 +17,11 @@ from typing import NamedTuple
from zipfile import BadZipFile
import chardet
import docx # type: ignore
import openpyxl # type: ignore
import pptx # type: ignore
from docx import Document as DocxDocument
from fastapi import UploadFile
from markitdown import FileConversionException
from markitdown import MarkItDown
from markitdown import UnsupportedFormatException
from PIL import Image
from pypdf import PdfReader
from pypdf.errors import PdfStreamError
@@ -83,11 +83,6 @@ IMAGE_MEDIA_TYPES = [
"image/webp",
]
KNOWN_OPENPYXL_BUGS = [
"Value must be either numerical or a string containing a wildcard",
"File contains no valid workbook part",
]
class OnyxExtensionType(IntFlag):
Plain = auto()
@@ -149,6 +144,13 @@ def is_macos_resource_fork_file(file_name: str) -> bool:
)
def to_bytesio(stream: IO[bytes]) -> BytesIO:
if isinstance(stream, BytesIO):
return stream
data = stream.read() # consumes the stream!
return BytesIO(data)
def load_files_from_zip(
zip_file_io: IO,
ignore_macos_resource_fork_files: bool = True,
@@ -305,19 +307,38 @@ def read_pdf_file(
return "", metadata, []
def extract_docx_images(docx_bytes: IO[Any]) -> list[tuple[bytes, str]]:
"""
Given the bytes of a docx file, extract all the images.
Returns a list of tuples (image_bytes, image_name).
"""
out = []
try:
with zipfile.ZipFile(docx_bytes) as z:
for name in z.namelist():
if name.startswith("word/media/"):
out.append((z.read(name), name.split("/")[-1]))
except Exception:
logger.exception("Failed to extract all docx images")
return out
def docx_to_text_and_images(
file: IO[Any], file_name: str = ""
) -> tuple[str, Sequence[tuple[bytes, str]]]:
"""
Extract text from a docx. If embed_images=True, also extract inline images.
Extract text from a docx.
Return (text_content, list_of_images).
"""
paragraphs = []
embedded_images: list[tuple[bytes, str]] = []
md = MarkItDown(enable_plugins=False)
try:
doc = docx.Document(file)
except (BadZipFile, ValueError) as e:
doc = md.convert(to_bytesio(file))
except (
BadZipFile,
ValueError,
FileConversionException,
UnsupportedFormatException,
) as e:
logger.warning(
f"Failed to extract docx {file_name or 'docx file'}: {e}. Attempting to read as text file."
)
@@ -330,96 +351,44 @@ def docx_to_text_and_images(
)
return text_content_raw or "", []
# Grab text from paragraphs
for paragraph in doc.paragraphs:
paragraphs.append(paragraph.text)
# Reset position so we can re-load the doc (python-docx has read the stream)
# Note: if python-docx has fully consumed the stream, you may need to open it again from memory.
# For large docs, a more robust approach is needed.
# This is a simplified example.
for rel_id, rel in doc.part.rels.items():
if "image" in rel.reltype:
# Skip images that are linked rather than embedded (TargetMode="External")
if getattr(rel, "is_external", False):
continue
try:
# image is typically in rel.target_part.blob
image_bytes = rel.target_part.blob
except ValueError:
# Safeguard against relationships that lack an internal target_part
# (e.g., external relationships or other anomalies)
continue
image_name = rel.target_part.partname
# store
embedded_images.append((image_bytes, os.path.basename(str(image_name))))
text_content = "\n".join(paragraphs)
return text_content, embedded_images
file.seek(0)
return doc.markdown, extract_docx_images(to_bytesio(file))
def pptx_to_text(file: IO[Any], file_name: str = "") -> str:
md = MarkItDown(enable_plugins=False)
try:
presentation = pptx.Presentation(file)
except BadZipFile as e:
presentation = md.convert(to_bytesio(file))
except (
BadZipFile,
ValueError,
FileConversionException,
UnsupportedFormatException,
) as e:
error_str = f"Failed to extract text from {file_name or 'pptx file'}: {e}"
logger.warning(error_str)
return ""
text_content = []
for slide_number, slide in enumerate(presentation.slides, start=1):
slide_text = f"\nSlide {slide_number}:\n"
for shape in slide.shapes:
if hasattr(shape, "text"):
slide_text += shape.text + "\n"
text_content.append(slide_text)
return TEXT_SECTION_SEPARATOR.join(text_content)
return presentation.markdown
def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
md = MarkItDown(enable_plugins=False)
try:
workbook = openpyxl.load_workbook(file, read_only=True)
except BadZipFile as e:
workbook = md.convert(to_bytesio(file))
except (
BadZipFile,
ValueError,
FileConversionException,
UnsupportedFormatException,
) as e:
error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}"
if file_name.startswith("~"):
logger.debug(error_str + " (this is expected for files with ~)")
else:
logger.warning(error_str)
return ""
except Exception as e:
if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS):
logger.error(
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
)
return ""
raise e
text_content = []
for sheet in workbook.worksheets:
rows = []
num_empty_consecutive_rows = 0
for row in sheet.iter_rows(min_row=1, values_only=True):
row_str = ",".join(str(cell or "") for cell in row)
# Only add the row if there are any values in the cells
if len(row_str) >= len(row):
rows.append(row_str)
num_empty_consecutive_rows = 0
else:
num_empty_consecutive_rows += 1
if num_empty_consecutive_rows > 100:
# handle massive excel sheets with mostly empty cells
logger.warning(
f"Found {num_empty_consecutive_rows} empty rows in {file_name},"
" skipping rest of file"
)
break
sheet_str = "\n".join(rows)
text_content.append(sheet_str)
return TEXT_SECTION_SEPARATOR.join(text_content)
return workbook.markdown
def eml_to_text(file: IO[Any]) -> str:
@@ -472,9 +441,9 @@ def extract_file_text(
"""
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
".docx": lambda f: docx_to_text_and_images(f)[0], # no images
".pptx": pptx_to_text,
".xlsx": xlsx_to_text,
".docx": lambda f: docx_to_text_and_images(f, file_name)[0], # no images
".pptx": lambda f: pptx_to_text(f, file_name),
".xlsx": lambda f: xlsx_to_text(f, file_name),
".eml": eml_to_text,
".epub": epub_to_text,
".html": parse_html_page_basic,
@@ -553,7 +522,7 @@ def extract_text_and_images(
# docx example for embedded images
if extension == ".docx":
text_content, images = docx_to_text_and_images(file)
text_content, images = docx_to_text_and_images(file, file_name)
return ExtractionResult(
text_content=text_content, embedded_images=images, metadata={}
)
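
Note on the markitdown switch above: the office parsers now convert via MarkItDown and return the result's markdown, falling back to a plain-text read on failure. A minimal sketch of the docx path, assuming MarkItDown.convert accepts an in-memory binary stream as the diff does:

from io import BytesIO
from typing import IO

from markitdown import FileConversionException, MarkItDown, UnsupportedFormatException


def to_bytesio(stream: IO[bytes]) -> BytesIO:
    if isinstance(stream, BytesIO):
        return stream
    return BytesIO(stream.read())  # consumes the original stream


def docx_to_markdown(file: IO[bytes]) -> str:
    md = MarkItDown(enable_plugins=False)
    try:
        result = md.convert(to_bytesio(file))
    except (FileConversionException, UnsupportedFormatException):
        return ""  # the real code falls back to reading the stream as text here
    return result.markdown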

View File

@@ -32,9 +32,11 @@ def is_valid_image_type(mime_type: str) -> bool:
Returns:
True if the MIME type is a valid image type, False otherwise
"""
if not mime_type:
return False
return mime_type.startswith("image/") and mime_type not in EXCLUDED_IMAGE_TYPES
return (
bool(mime_type)
and mime_type.startswith("image/")
and mime_type not in EXCLUDED_IMAGE_TYPES
)
def is_supported_by_vision_llm(mime_type: str) -> bool:

View File

@@ -1,6 +1,7 @@
import re
from copy import copy
from dataclasses import dataclass
from io import BytesIO
from typing import IO
import bs4
@@ -161,7 +162,7 @@ def format_document_soup(
return strip_excessive_newlines_and_spaces(text)
def parse_html_page_basic(text: str | IO[bytes]) -> str:
def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str:
soup = bs4.BeautifulSoup(text, "html.parser")
return format_document_soup(soup)

View File

@@ -196,6 +196,9 @@ class FileStoreDocumentBatchStorage(DocumentBatchStorage):
for batch_file_name in batch_names:
path_info = self.extract_path_info(batch_file_name)
if path_info is None:
logger.warning(
f"Could not extract path info from batch file: {batch_file_name}"
)
continue
new_batch_file_name = self._get_batch_file_name(path_info.batch_num)
self.file_store.change_file_id(batch_file_name, new_batch_file_name)

View File

@@ -49,11 +49,10 @@ def sanitize_s3_key_name(file_name: str) -> str:
# Characters to avoid completely (replace with underscore)
# These are characters that AWS recommends avoiding
avoid_chars = r'[\\{}^%`\[\]"<>#|~]'
avoid_chars = r'[\\{}^%`\[\]"<>#|~/]'
# Replace avoided characters with underscore
sanitized = re.sub(avoid_chars, "_", file_name)
# Characters that might require special handling but are allowed
# We'll URL encode these to be safe
special_chars = r"[&$@=;:+,?\s]"
@@ -81,6 +80,9 @@ def sanitize_s3_key_name(file_name: str) -> str:
# Remove any trailing periods to avoid download issues
sanitized = sanitized.rstrip(".")
# Remove multiple separators
sanitized = re.sub(r"[-_]{2,}", "-", sanitized)
# If sanitization resulted in empty string, use a default
if not sanitized:
sanitized = "sanitized_file"
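
Note on the S3 key changes above: "/" joins the replaced characters and runs of separators are collapsed. A partial reconstruction using only the steps visible in the diff (the URL-encoding of special characters is omitted):

import re

def sanitize_sketch(file_name: str) -> str:
    # replace AWS-discouraged characters, now including "/", with underscore
    sanitized = re.sub(r'[\\{}^%`\[\]"<>#|~/]', "_", file_name)
    sanitized = sanitized.rstrip(".")  # trailing periods break downloads
    sanitized = re.sub(r"[-_]{2,}", "-", sanitized)  # collapse separator runs
    return sanitized or "sanitized_file"

print(sanitize_sketch("reports/2025#final.pdf."))  # -> reports_2025_final.pdf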

View File

@@ -46,7 +46,6 @@ def store_user_file_plaintext(user_file_id: int, plaintext_content: str) -> bool
# Get plaintext file name
plaintext_file_name = user_file_id_to_plaintext_file_name(user_file_id)
# Use a separate session to avoid committing the caller's transaction
try:
file_store = get_default_file_store()
file_content = BytesIO(plaintext_content.encode("utf-8"))

View File

@@ -867,31 +867,27 @@ def index_doc_batch(
user_file_id_to_raw_text: dict[int, str] = {}
for document_id in updatable_ids:
# Only calculate token counts for documents that have a user file ID
if (
document_id in doc_id_to_user_file_id
and doc_id_to_user_file_id[document_id] is not None
):
user_file_id = doc_id_to_user_file_id[document_id]
if not user_file_id:
continue
document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content))
if llm_tokenizer
else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None
user_file_id = doc_id_to_user_file_id.get(document_id)
if user_file_id is None:
continue
document_chunks = [
chunk
for chunk in chunks_with_embeddings
if chunk.source_document.id == document_id
]
if document_chunks:
combined_content = " ".join(
[chunk.content for chunk in document_chunks]
)
token_count = (
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
)
user_file_id_to_token_count[user_file_id] = token_count
user_file_id_to_raw_text[user_file_id] = combined_content
else:
user_file_id_to_token_count[user_file_id] = None
# we're concerned about race conditions where multiple simultaneous indexings might result
# in one set of metadata overwriting another one in vespa.
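
The guard-clause rewrite via dict.get flattens the nesting; note it also switches the skip condition from falsy to `is None`, so a user file id of 0 would now be processed. Under illustrative names, the control flow reduces to:

def per_file_token_counts(doc_ids, doc_to_user_file, chunk_texts, tokenizer=None):
    """Illustrative names; sketches the flattened guard-clause flow."""
    counts: dict[int, int | None] = {}
    for doc_id in doc_ids:
        user_file_id = doc_to_user_file.get(doc_id)
        if user_file_id is None:
            continue  # guard clause replaces the old nested `if`
        texts = chunk_texts.get(doc_id, [])
        if texts:
            combined = " ".join(texts)
            counts[user_file_id] = len(tokenizer.encode(combined)) if tokenizer else 0
        else:
            counts[user_file_id] = None
    return counts

print(per_file_token_counts([1, 2], {1: 10}, {1: ["a b", "c"]}))  # {10: 0}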

View File

@@ -313,14 +313,14 @@ class DefaultMultiLLM(LLM):
self._model_kwargs = model_kwargs
def log_model_configs(self) -> None:
logger.debug(f"Config: {self.config}")
def _safe_model_config(self) -> dict:
dump = self.config.model_dump()
dump["api_key"] = mask_string(dump.get("api_key", ""))
return dump
def log_model_configs(self) -> None:
logger.debug(f"Config: {self._safe_model_config()}")
def _record_call(self, prompt: LanguageModelInput) -> None:
if self._long_term_logger:
self._long_term_logger.record(
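
mask_string is imported from elsewhere in the codebase and its implementation isn't shown in this diff; a plausible stand-in that captures the intent of keeping raw API keys out of logs:

def mask_string(value: str, visible: int = 4) -> str:
    # hypothetical stand-in: keep a short suffix so keys stay identifiable
    if not value:
        return ""
    return "****" + value[-visible:]

print(mask_string("sk-abc123xyz789"))  # ****z789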
@@ -397,7 +397,11 @@ class DefaultMultiLLM(LLM):
# streaming choice
stream=stream,
# model params
temperature=self._temperature,
temperature=(
1
if self.config.model_name in ["gpt-5", "gpt-5-mini", "gpt-5-nano"]
else self._temperature
),
timeout=timeout_override or self._timeout,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
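
The gpt-5 family appears to accept only the default temperature of 1, as with earlier reasoning models, so the call site pins it; that constraint is an inference, not something this diff states. As a standalone predicate:

# model list copied from the diff; the fixed-temperature constraint itself is
# an assumption based on OpenAI's reasoning-model behavior
FIXED_TEMPERATURE_MODELS = {"gpt-5", "gpt-5-mini", "gpt-5-nano"}

def effective_temperature(model_name: str, requested: float) -> float:
    return 1 if model_name in FIXED_TEMPERATURE_MODELS else requested

assert effective_temperature("gpt-5-mini", 0.2) == 1
assert effective_temperature("gpt-4o", 0.2) == 0.2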

View File

@@ -47,6 +47,9 @@ class WellKnownLLMProviderDescriptor(BaseModel):
OPENAI_PROVIDER_NAME = "openai"
OPEN_AI_MODEL_NAMES = [
"gpt-5",
"gpt-5-mini",
"gpt-5-nano",
"o4-mini",
"o3-mini",
"o1-mini",
@@ -73,7 +76,14 @@ OPEN_AI_MODEL_NAMES = [
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
]
OPEN_AI_VISIBLE_MODEL_NAMES = ["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"]
OPEN_AI_VISIBLE_MODEL_NAMES = [
"gpt-5",
"gpt-5-mini",
"o1",
"o3-mini",
"gpt-4o",
"gpt-4o-mini",
]
BEDROCK_PROVIDER_NAME = "bedrock"
# need to remove all the weird "bedrock/eu-central-1/anthropic.claude-v1" named

View File

@@ -151,7 +151,7 @@ def _build_ephemeral_publication_block(
email=message_info.email,
sender_id=message_info.sender_id,
thread_messages=[],
is_bot_msg=message_info.is_bot_msg,
is_slash_command=message_info.is_slash_command,
is_bot_dm=message_info.is_bot_dm,
thread_to_respond=respond_ts,
)
@@ -225,10 +225,10 @@ def _build_doc_feedback_block(
def get_restate_blocks(
msg: str,
is_bot_msg: bool,
is_slash_command: bool,
) -> list[Block]:
# Only the slash command needs this context because the user doesn't see their own input
if not is_bot_msg:
if not is_slash_command:
return []
return [
@@ -576,7 +576,7 @@ def build_slack_response_blocks(
# If called with the OnyxBot slash command, the question is lost so we have to reshow it
if not skip_restated_question:
restate_question_block = get_restate_blocks(
message_info.thread_messages[-1].message, message_info.is_bot_msg
message_info.thread_messages[-1].message, message_info.is_slash_command
)
else:
restate_question_block = []

View File

@@ -177,7 +177,7 @@ def handle_generate_answer_button(
sender_id=user_id or None,
email=email or None,
bypass_filters=True,
is_bot_msg=False,
is_slash_command=False,
is_bot_dm=False,
),
slack_channel_config=slack_channel_config,

View File

@@ -28,7 +28,7 @@ logger_base = setup_logger()
def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None:
if details.is_bot_msg and details.sender_id:
if details.is_slash_command and details.sender_id:
respond_in_thread_or_channel(
client=client,
channel=details.channel_to_respond,
@@ -124,11 +124,11 @@ def handle_message(
messages = message_info.thread_messages
sender_id = message_info.sender_id
bypass_filters = message_info.bypass_filters
is_bot_msg = message_info.is_bot_msg
is_slash_command = message_info.is_slash_command
is_bot_dm = message_info.is_bot_dm
action = "slack_message"
if is_bot_msg:
if is_slash_command:
action = "slack_slash_message"
elif bypass_filters:
action = "slack_tag_message"
@@ -197,7 +197,7 @@ def handle_message(
# If configured to respond to team members only, then cannot be used with a /OnyxBot command
# which would just respond to the sender
if send_to and is_bot_msg:
if send_to and is_slash_command:
if sender_id:
respond_in_thread_or_channel(
client=client,

View File

@@ -81,15 +81,15 @@ def handle_regular_answer(
messages = message_info.thread_messages
message_ts_to_respond_to = message_info.msg_to_respond
is_bot_msg = message_info.is_bot_msg
is_slash_command = message_info.is_slash_command
# Capture whether response mode for channel is ephemeral. Even if the channel is set
# to respond with an ephemeral message, we still send as non-ephemeral if
# the message is a dm with the Onyx bot.
send_as_ephemeral = (
slack_channel_config.channel_config.get("is_ephemeral", False)
and not message_info.is_bot_dm
)
or message_info.is_slash_command
) and not message_info.is_bot_dm
# If the channel is configured to respond with an ephemeral message,
# or the message is a dm to the Onyx bot, we should use the proper onyx user from the email.
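
The precedence here matters: an ephemeral channel config OR a slash-command invocation makes the reply ephemeral, but a DM always wins. Reduced to a predicate:

def send_as_ephemeral(is_ephemeral_channel: bool, is_slash_command: bool, is_bot_dm: bool) -> bool:
    # DMs are never ephemeral, regardless of channel config or slash usage
    return (is_ephemeral_channel or is_slash_command) and not is_bot_dm

assert send_as_ephemeral(False, True, False)     # slash command in a channel
assert not send_as_ephemeral(True, False, True)  # DM overrides the channel setting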
@@ -164,7 +164,7 @@ def handle_regular_answer(
# in an attached document set were available to all users in the channel.)
bypass_acl = False
if not message_ts_to_respond_to and not is_bot_msg:
if not message_ts_to_respond_to and not is_slash_command:
# if the message is not a "/onyx" command, it should have a message ts to respond to
raise RuntimeError(
"No message timestamp to respond to in `handle_message`. This should never happen."
@@ -316,13 +316,14 @@ def handle_regular_answer(
return True
# Got an answer at this point, can remove reaction and give results
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
if not is_slash_command: # Slash commands don't have reactions
update_emote_react(
emoji=DANSWER_REACT_EMOJI,
channel=message_info.channel_to_respond,
message_ts=message_info.msg_to_respond,
remove=True,
client=client,
)
if answer.answer_valid is False:
logger.notice(

View File

@@ -876,12 +876,13 @@ def build_request_details(
sender_id=sender_id,
email=email,
bypass_filters=tagged,
is_bot_msg=False,
is_slash_command=False,
is_bot_dm=event.get("channel_type") == "im",
)
elif req.type == "slash_commands":
channel = req.payload["channel_id"]
channel_name = req.payload["channel_name"]
msg = req.payload["text"]
sender = req.payload["user_id"]
expert_info = expert_info_from_slack_id(
@@ -899,8 +900,8 @@ def build_request_details(
sender_id=sender,
email=email,
bypass_filters=True,
is_bot_msg=True,
is_bot_dm=False,
is_slash_command=True,
is_bot_dm=channel_name == "directmessage",
)
raise RuntimeError("Programming fault, this should never happen.")

View File

@@ -13,7 +13,7 @@ class SlackMessageInfo(BaseModel):
sender_id: str | None
email: str | None
bypass_filters: bool # User has tagged @OnyxBot
is_bot_msg: bool # User is using /OnyxBot
is_slash_command: bool # User is using /OnyxBot
is_bot_dm: bool # User is direct messaging to OnyxBot
@@ -25,7 +25,7 @@ class ActionValuesEphemeralMessageMessageInfo(BaseModel):
email: str | None
sender_id: str | None
thread_messages: list[ThreadMessage] | None
is_bot_msg: bool | None
is_slash_command: bool | None
is_bot_dm: bool | None
thread_to_respond: str | None
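
Abbreviated to the fields this rename touches, constructing the model looks like the following (values illustrative):

from pydantic import BaseModel

class SlackMessageInfo(BaseModel):  # abbreviated to the renamed fields
    bypass_filters: bool      # user has tagged @OnyxBot
    is_slash_command: bool    # user is using /OnyxBot
    is_bot_dm: bool           # user is direct messaging OnyxBot

info = SlackMessageInfo(bypass_filters=False, is_slash_command=True, is_bot_dm=False)
print(info.is_slash_command)  # True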

View File

@@ -3,7 +3,6 @@ import redis
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_connector_stop import RedisConnectorStop
from onyx.redis.redis_pool import get_redis_client
@@ -31,11 +30,6 @@ class RedisConnector:
tenant_id, cc_pair_id, self.redis
)
def new_index(self, search_settings_id: int) -> RedisConnectorIndex:
return RedisConnectorIndex(
self.tenant_id, self.cc_pair_id, search_settings_id, self.redis
)
@staticmethod
def get_id_from_fence_key(key: str) -> str | None:
"""
@@ -81,3 +75,11 @@ class RedisConnector:
object_id = parts[1]
return object_id
def db_lock_key(self, search_settings_id: int) -> str:
"""
Key for the db lock for an indexing attempt.
Prevents concurrent docfetching/docprocessing tasks from modifying
the current indexing attempt row.
"""
return f"da_lock:indexing:db_{self.cc_pair_id}/{search_settings_id}"

View File

@@ -1,126 +1,10 @@
from datetime import datetime
from typing import cast
import redis
from pydantic import BaseModel
from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
class RedisConnectorIndexPayload(BaseModel):
index_attempt_id: int | None
started: datetime | None
submitted: datetime
celery_task_id: str | None
class RedisConnectorIndex:
"""Manages interactions with redis for indexing tasks. Should only be accessed
through RedisConnector."""
PREFIX = "connectorindexing"
FENCE_PREFIX = f"{PREFIX}_fence" # "connectorindexing_fence"
GENERATOR_TASK_PREFIX = PREFIX + "+generator" # "connectorindexing+generator_fence"
GENERATOR_PROGRESS_PREFIX = (
PREFIX + "_generator_progress"
) # connectorindexing_generator_progress
GENERATOR_COMPLETE_PREFIX = (
PREFIX + "_generator_complete"
) # connectorindexing_generator_complete
GENERATOR_LOCK_PREFIX = "da_lock:indexing:docfetching"
FILESTORE_LOCK_PREFIX = "da_lock:indexing:filestore"
DB_LOCK_PREFIX = "da_lock:indexing:db"
PER_WORKER_LOCK_PREFIX = "da_lock:indexing:per_worker"
TERMINATE_PREFIX = PREFIX + "_terminate" # connectorindexing_terminate
TERMINATE_TTL = 600
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
# used to signal that the watchdog is running
WATCHDOG_PREFIX = PREFIX + "_watchdog"
WATCHDOG_TTL = 300
# used to signal that the connector itself is still running
CONNECTOR_ACTIVE_PREFIX = PREFIX + "_connector_active"
CONNECTOR_ACTIVE_TTL = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT
def __init__(
self,
tenant_id: str,
cc_pair_id: int,
search_settings_id: int,
redis: redis.Redis,
) -> None:
self.tenant_id: str = tenant_id
self.cc_pair_id = cc_pair_id
self.search_settings_id = search_settings_id
self.redis = redis
self.generator_complete_key = (
f"{self.GENERATOR_COMPLETE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.filestore_lock_key = (
f"{self.FILESTORE_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.generator_lock_key = (
f"{self.GENERATOR_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.per_worker_lock_key = (
f"{self.PER_WORKER_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
self.db_lock_key = f"{self.DB_LOCK_PREFIX}_{cc_pair_id}/{search_settings_id}"
self.terminate_key = (
f"{self.TERMINATE_PREFIX}_{cc_pair_id}/{search_settings_id}"
)
def set_generator_complete(self, payload: int | None) -> None:
if not payload:
self.redis.delete(self.generator_complete_key)
return
self.redis.set(self.generator_complete_key, payload)
def generator_clear(self) -> None:
self.redis.delete(self.generator_complete_key)
def get_completion(self) -> int | None:
bytes = self.redis.get(self.generator_complete_key)
if bytes is None:
return None
status = int(cast(int, bytes))
return status
def reset(self) -> None:
self.redis.delete(self.filestore_lock_key)
self.redis.delete(self.db_lock_key)
self.redis.delete(self.generator_lock_key)
self.redis.delete(self.generator_complete_key)
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
# leaving these temporarily for backwards compat, TODO: remove
for key in r.scan_iter(RedisConnectorIndex.CONNECTOR_ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.FILESTORE_LOCK_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_COMPLETE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.GENERATOR_PROGRESS_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorIndex.FENCE_PREFIX + "*"):
r.delete(key)

View File

@@ -1,6 +1,5 @@
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_usergroup import RedisUserGroup
@@ -16,8 +15,6 @@ def is_fence(key_bytes: bytes) -> bool:
return True
if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorIndex.FENCE_PREFIX):
return True
if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
return True

View File

@@ -1,3 +1,4 @@
import io
import json
import mimetypes
import os
@@ -101,8 +102,9 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import User
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.file_processing.extract_file_text import convert_docx_to_txt
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.documents.models import AuthStatus
from onyx.server.documents.models import AuthUrl
@@ -124,6 +126,7 @@ from onyx.server.documents.models import IndexAttemptSnapshot
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.models import RunConnectorRequest
from onyx.server.models import StatusResponse
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -438,7 +441,9 @@ def is_zip_file(file: UploadFile) -> bool:
)
def upload_files(files: list[UploadFile]) -> FileUploadResponse:
def upload_files(
files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR
) -> FileUploadResponse:
for file in files:
if not file.filename:
raise HTTPException(status_code=400, detail="File name cannot be empty")
@@ -487,12 +492,17 @@ def upload_files(files: list[UploadFile]) -> FileUploadResponse:
# For mypy, actual check happens at start of function
assert file.filename is not None
# Special handling for docx files - only store the plaintext version
if file.content_type and file.content_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
docx_file_id = convert_docx_to_txt(file, file_store)
deduped_file_paths.append(docx_file_id)
# Special handling for doc files - only store the plaintext version
file_type = mime_type_to_chat_file_type(file.content_type)
if file_type == ChatFileType.DOC:
extracted_text = extract_file_text(file.file, file.filename or "")
text_file_id = file_store.save_file(
content=io.BytesIO(extracted_text.encode()),
display_name=file.filename,
file_origin=file_origin,
file_type="text/plain",
)
deduped_file_paths.append(text_file_id)
deduped_file_names.append(file.filename)
continue
@@ -520,7 +530,7 @@ def upload_files_api(
files: list[UploadFile],
_: User = Depends(current_curator_or_admin_user),
) -> FileUploadResponse:
return upload_files(files)
return upload_files(files, FileOrigin.OTHER)
@router.get("/admin/connector")
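
mime_type_to_chat_file_type (imported from chat_utils) replaces the hand-rolled startswith check. Its real dispatch isn't shown in this diff; a hypothetical sketch of the shape:

from enum import Enum

class ChatFileType(str, Enum):  # trimmed to the members used here
    DOC = "document"
    IMAGE = "image"
    PLAIN_TEXT = "plain_text"

# hypothetical prefixes; the real helper may match differently
_DOC_MIME_PREFIXES = (
    "application/vnd.openxmlformats-officedocument",
    "application/msword",
    "application/pdf",
)

def mime_type_to_chat_file_type(mime_type: str | None) -> ChatFileType:
    if mime_type and mime_type.startswith("image/"):
        return ChatFileType.IMAGE
    if mime_type and mime_type.startswith(_DOC_MIME_PREFIXES):
        return ChatFileType.DOC
    return ChatFileType.PLAIN_TEXT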

View File

@@ -1,6 +1,5 @@
import asyncio
import datetime
import io
import json
import os
import time
@@ -31,7 +30,6 @@ from onyx.chat.prompt_builder.citations_prompt import (
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.model_configs import LITELLM_PASS_THROUGH_HEADERS
@@ -63,9 +61,7 @@ from onyx.db.models import User
from onyx.db.persona import get_persona_by_id
from onyx.db.user_documents import create_user_files
from onyx.file_processing.extract_file_text import docx_to_txt_filename
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_default_llms
@@ -717,106 +713,65 @@ def upload_files_for_chat(
):
raise HTTPException(
status_code=400,
detail="File size must be less than 20MB",
detail="Images must be less than 20MB",
)
file_store = get_default_file_store()
file_info: list[tuple[str, str | None, ChatFileType]] = []
for file in files:
file_type = mime_type_to_chat_file_type(file.content_type)
file_content = file.file.read() # Read the file content
# NOTE: Image conversion to JPEG used to be enforced here.
# This was removed to:
# 1. Preserve original file content for downloads
# 2. Maintain transparency in formats like PNG
# 3. Avoid issues with file conversion
file_content_io = io.BytesIO(file_content)
new_content_type = file.content_type
# Store the file normally
file_id = file_store.save_file(
content=file_content_io,
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=new_content_type or file_type.value,
# 5) Create a user file for each uploaded file
user_files = create_user_files(files, RECENT_DOCS_FOLDER_ID, user, db_session)
for user_file in user_files:
# 6) Create connector
connector_base = ConnectorBase(
name=f"UserFile-{int(time.time())}",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"file_names": [user_file.name],
"zip_metadata": {},
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)
connector = create_connector(
db_session=db_session,
connector_data=connector_base,
)
# 4) If the file is a doc, extract text and store that separately
if file_type == ChatFileType.DOC:
# Re-wrap bytes in a fresh BytesIO so we start at position 0
extracted_text_io = io.BytesIO(file_content)
extracted_text = extract_file_text(
file=extracted_text_io, # use the bytes we already read
file_name=file.filename or "",
)
# 7) Create credential
credential_info = CredentialBase(
credential_json={},
admin_public=True,
source=DocumentSource.FILE,
curator_public=True,
groups=[],
name=f"UserFileCredential-{int(time.time())}",
is_user_file=True,
)
credential = create_credential(credential_info, user, db_session)
text_file_id = file_store.save_file(
content=io.BytesIO(extracted_text.encode()),
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type="text/plain",
)
# Return the text file as the "main" file descriptor for doc types
file_info.append((text_file_id, file.filename, ChatFileType.PLAIN_TEXT))
else:
file_info.append((file_id, file.filename, file_type))
# 5) Create a user file for each uploaded file
user_files = create_user_files([file], RECENT_DOCS_FOLDER_ID, user, db_session)
for user_file in user_files:
# 6) Create connector
connector_base = ConnectorBase(
name=f"UserFile-{int(time.time())}",
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": [user_file.file_id],
"file_names": [user_file.name],
"zip_metadata": {},
},
refresh_freq=None,
prune_freq=None,
indexing_start=None,
)
connector = create_connector(
db_session=db_session,
connector_data=connector_base,
)
# 7) Create credential
credential_info = CredentialBase(
credential_json={},
admin_public=True,
source=DocumentSource.FILE,
curator_public=True,
groups=[],
name=f"UserFileCredential-{int(time.time())}",
is_user_file=True,
)
credential = create_credential(credential_info, user, db_session)
# 8) Create connector credential pair
cc_pair = add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=connector.id,
credential_id=credential.id,
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
access_type=AccessType.PRIVATE,
auto_sync_options=None,
groups=[],
)
user_file.cc_pair_id = cc_pair.data
db_session.commit()
# 8) Create connector credential pair
cc_pair = add_credential_to_connector(
db_session=db_session,
user=user,
connector_id=connector.id,
credential_id=credential.id,
cc_pair_name=f"UserFileCCPair-{int(time.time())}",
access_type=AccessType.PRIVATE,
auto_sync_options=None,
groups=[],
)
user_file.cc_pair_id = cc_pair.data
db_session.commit()
return {
"files": [
{"id": file_id, "type": file_type, "name": file_name}
for file_id, file_name, file_type in file_info
{
"id": user_file.file_id,
"type": mime_type_to_chat_file_type(user_file.content_type),
"name": user_file.name,
}
for user_file in user_files
]
}

View File

@@ -44,12 +44,12 @@ litellm==1.72.2
lxml==5.3.0
lxml_html_clean==0.2.2
Mako==1.2.4
markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2
msal==1.28.0
nltk==3.9.1
Office365-REST-Python-Client==2.5.9
oauthlib==3.2.2
openai==1.75.0
openpyxl==3.0.10
passlib==1.7.4
playwright==1.41.2
psutil==5.9.5
@@ -66,7 +66,7 @@ pypdf==5.4.0
pytest-mock==3.12.0
pytest-playwright==0.7.0
python-docx==1.1.2
python-dotenv==1.0.0
python-dotenv==1.1.1
python-multipart==0.0.20
pywikibot==9.0.0
redis==5.0.8
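
markitdown with the pdf/docx/pptx/xlsx/xls extras drives the new office parsing; the basic 0.1.x API is a single converter object (file path illustrative):

from markitdown import MarkItDown

md = MarkItDown()
result = md.convert("quarterly_report.xlsx")  # illustrative path
print(result.text_content[:200])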

View File

@@ -22,7 +22,6 @@ from onyx.configs.app_configs import REDIS_SSL
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_index import RedisConnectorIndex
from onyx.redis.redis_pool import RedisPool
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
@@ -130,9 +129,6 @@ def onyx_redis(
logger.info(f"Purging locks associated with deleting cc_pair={cc_pair_id}.")
redis_connector = RedisConnector(tenant_id, cc_pair_id)
match_pattern = f"{tenant_id}:{RedisConnectorIndex.FENCE_PREFIX}_{cc_pair_id}/*"
purge_by_match_and_type(match_pattern, "string", batch, dry_run, r)
redis_delete_if_exists_helper(
f"{tenant_id}:{redis_connector.prune.fence_key}", dry_run, r
)

View File

@@ -11,6 +11,7 @@ import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
def extract_key_value_pairs_to_set(
@@ -35,7 +36,7 @@ def _load_reference_data(
@pytest.fixture
def salesforce_connector() -> SalesforceConnector:
connector = SalesforceConnector(
requested_objects=["Account", "Contact", "Opportunity"],
requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact", "Opportunity"],
)
username = os.environ["SF_USERNAME"]

View File

@@ -21,6 +21,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
FILE_NAME = "Sample.pdf"
FILE_PATH = "tests/integration/common_utils/test_files"
DOCX_FILE_NAME = "three_images.docx"
def test_image_indexing(
@@ -114,3 +115,112 @@ def test_image_indexing(
else:
assert document.image_file_id is not None
assert file_paths[0] in document.image_file_id
def test_docx_image_indexing(
reset: None,
admin_user: DATestUser,
vespa_client: vespa_fixture,
) -> None:
"""Test that images from docx files are correctly extracted and indexed."""
os.makedirs(FILE_PATH, exist_ok=True)
test_file_path = os.path.join(FILE_PATH, DOCX_FILE_NAME)
# Use FileManager to upload the test file
upload_response = FileManager.upload_file_for_connector(
file_path=test_file_path,
file_name=DOCX_FILE_NAME,
user_performing_action=admin_user,
)
LLMProviderManager.create(
name="test_llm_docx",
user_performing_action=admin_user,
)
SettingsManager.update_settings(
DATestSettings(
search_time_image_analysis_enabled=True,
image_extraction_and_analysis_enabled=True,
),
user_performing_action=admin_user,
)
file_paths = upload_response.file_paths
if not file_paths:
pytest.fail("File upload failed - no file paths returned")
# Create a dummy credential for the file connector
credential = CredentialManager.create(
source=DocumentSource.FILE,
credential_json={},
user_performing_action=admin_user,
)
# Create the connector
connector_name = f"DocxFileConnector-{int(datetime.now().timestamp())}"
connector = ConnectorManager.create(
name=connector_name,
source=DocumentSource.FILE,
input_type=InputType.LOAD_STATE,
connector_specific_config={
"file_locations": file_paths,
"file_names": [DOCX_FILE_NAME],
"zip_metadata": {},
},
access_type=AccessType.PUBLIC,
groups=[],
user_performing_action=admin_user,
)
# Link the credential to the connector
cc_pair = CCPairManager.create(
credential_id=credential.id,
connector_id=connector.id,
access_type=AccessType.PUBLIC,
user_performing_action=admin_user,
)
# Explicitly run the connector to start indexing
CCPairManager.run_once(
cc_pair=cc_pair,
from_beginning=True,
user_performing_action=admin_user,
)
CCPairManager.wait_for_indexing_completion(
cc_pair=cc_pair,
after=datetime.now(timezone.utc),
timeout=300,
user_performing_action=admin_user,
)
with get_session_with_current_tenant() as db_session:
# Fetch documents from Vespa - expect text content plus 3 images
documents = DocumentManager.fetch_documents_for_cc_pair(
cc_pair_id=cc_pair.id,
db_session=db_session,
vespa_client=vespa_client,
)
# Should have documents for text content plus 3 images
assert (
len(documents) >= 3
), f"Expected at least 3 documents (3 images plus text), got {len(documents)}"
# Count documents with images
image_documents = [doc for doc in documents if doc.image_file_id is not None]
text_documents = [doc for doc in documents if doc.image_file_id is None]
assert (
len(image_documents) == 3
), f"Expected exactly 3 image documents, got {len(image_documents)}"
assert (
len(text_documents) >= 1
), f"Expected at least 1 text document, got {len(text_documents)}"
# Verify each image document has a valid image_file_id pointing to our uploaded file
for image_doc in image_documents:
assert file_paths[0] in (
image_doc.image_file_id or ""
), f"Image document should reference uploaded file: {image_doc.image_file_id}"

View File

@@ -0,0 +1,129 @@
#!/usr/bin/env python3
"""
Test script for the new custom query configuration functionality in SalesforceConnector.
This demonstrates how to use the new custom_query_config parameter to specify
exactly which fields and associations (child objects) to retrieve for each object type.
"""
import json
from typing import Any
from onyx.connectors.salesforce.connector import _validate_custom_query_config
from onyx.connectors.salesforce.connector import SalesforceConnector
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
def test_custom_query_config() -> None:
"""Test the custom query configuration functionality."""
# Example custom query configuration
# This specifies exactly which fields and associations to retrieve
custom_config = {
ACCOUNT_OBJECT_TYPE: {
"fields": ["Id", "Name", "Industry", "CreatedDate", MODIFIED_FIELD],
"associations": {
"Contact": ["Id", "FirstName", "LastName", "Email"],
"Opportunity": ["Id", "Name", "StageName", "Amount", "CloseDate"],
},
},
"Lead": {
"fields": ["Id", "FirstName", "LastName", "Company", "Status"],
"associations": {}, # No associations for Lead
},
}
# Create connector with custom configuration
connector = SalesforceConnector(
batch_size=50, custom_query_config=json.dumps(custom_config)
)
print("✅ SalesforceConnector created successfully with custom query config")
print(f"Parent object list: {connector.parent_object_list}")
print(f"Custom config keys: {list(custom_config.keys())}")
# Test that the parent object list is derived from the custom config
assert connector.parent_object_list == [ACCOUNT_OBJECT_TYPE, "Lead"]
assert connector.custom_query_config == custom_config
print("✅ Basic validation passed")
def test_traditional_config() -> None:
"""Test that the traditional requested_objects approach still works."""
# Traditional approach
connector = SalesforceConnector(
batch_size=50, requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact"]
)
print("✅ SalesforceConnector created successfully with traditional config")
print(f"Parent object list: {connector.parent_object_list}")
# Test that it still works the old way
assert connector.parent_object_list == [ACCOUNT_OBJECT_TYPE, "Contact"]
assert connector.custom_query_config is None
print("✅ Traditional config validation passed")
def test_validation() -> None:
"""Test that invalid configurations are rejected."""
# Test invalid config structure
invalid_configs: list[Any] = [
# Invalid fields type
{ACCOUNT_OBJECT_TYPE: {"fields": "invalid"}},
# Invalid associations type
{ACCOUNT_OBJECT_TYPE: {"associations": "invalid"}},
# Nested invalid structure
{ACCOUNT_OBJECT_TYPE: {"associations": {"Contact": {"fields": "invalid"}}}},
]
for i, invalid_config in enumerate(invalid_configs):
try:
_validate_custom_query_config(invalid_config)
assert False, f"Should have raised ValueError for invalid_config[{i}]"
except ValueError:
print(f"✅ Correctly rejected invalid config {i}")
if __name__ == "__main__":
print("Testing SalesforceConnector custom query configuration...")
print("=" * 60)
test_custom_query_config()
print()
test_traditional_config()
print()
test_validation()
print()
print("=" * 60)
print("🎉 All tests passed! The custom query configuration is working correctly.")
print()
print("Example usage:")
print(
"""
# Custom configuration approach
custom_config = {
ACCOUNT_OBJECT_TYPE: {
"fields": ["Id", "Name", "Industry"],
"associations": {
"Contact": {
"fields": ["Id", "FirstName", "LastName", "Email"],
"associations": {}
}
}
}
}
connector = SalesforceConnector(custom_query_config=custom_config)
# Traditional approach (still works)
connector = SalesforceConnector(requested_objects=[ACCOUNT_OBJECT_TYPE, "Contact"])
"""
)

View File

@@ -26,6 +26,9 @@ from onyx.connectors.salesforce.salesforce_calls import _make_time_filter_for_sf
from onyx.connectors.salesforce.salesforce_calls import _make_time_filtered_query
from onyx.connectors.salesforce.salesforce_calls import get_object_by_id_query
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.utils.logger import setup_logger
# from onyx.connectors.salesforce.onyx_salesforce_type import OnyxSalesforceType
@@ -153,7 +156,7 @@ def _create_csv_file_and_update_db(
Creates a CSV file for the given object type and records.
Args:
object_type: The Salesforce object type (e.g. "Account", "Contact")
object_type: The Salesforce object type (e.g. ACCOUNT_OBJECT_TYPE, "Contact")
records: List of dictionaries containing the record data
filename: Name of the CSV file to create (default: test_data.csv)
"""
@@ -184,7 +187,7 @@ def _create_csv_with_example_data(sf_db: OnyxSalesforceSQLite) -> None:
Creates CSV files with example data, organized by object type.
"""
example_data: dict[str, list[dict]] = {
"Account": [
ACCOUNT_OBJECT_TYPE: [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Acme Inc.",
@@ -428,7 +431,7 @@ def _test_query(sf_db: OnyxSalesforceSQLite) -> None:
}
# Get all Account IDs
account_ids = sf_db.find_ids_by_type("Account")
account_ids = sf_db.find_ids_by_type(ACCOUNT_OBJECT_TYPE)
# Verify we found all expected accounts
assert len(account_ids) == len(
@@ -480,7 +483,9 @@ def _test_upsert(sf_db: OnyxSalesforceSQLite) -> None:
},
]
_create_csv_file_and_update_db(sf_db, "Account", update_data, "update_data.csv")
_create_csv_file_and_update_db(
sf_db, ACCOUNT_OBJECT_TYPE, update_data, "update_data.csv"
)
# Verify the update worked
updated_record = sf_db.get_record(_VALID_SALESFORCE_IDS[0])
@@ -573,7 +578,7 @@ def _test_account_with_children(sf_db: OnyxSalesforceSQLite) -> None:
3. Child object data is complete and accurate
"""
# First get all account IDs
account_ids = sf_db.find_ids_by_type("Account")
account_ids = sf_db.find_ids_by_type(ACCOUNT_OBJECT_TYPE)
assert len(account_ids) > 0, "No accounts found"
# For each account, get its children and verify the data
@@ -690,7 +695,7 @@ def _test_get_affected_parent_ids(sf_db: OnyxSalesforceSQLite) -> None:
"""
# Create test data with relationships
test_data = {
"Account": [
ACCOUNT_OBJECT_TYPE: [
{
"Id": _VALID_SALESFORCE_IDS[0],
"Name": "Parent Account 1",
@@ -720,40 +725,46 @@ def _test_get_affected_parent_ids(sf_db: OnyxSalesforceSQLite) -> None:
# Test Case 1: Account directly in updated_ids and parent_types
updated_ids = [_VALID_SALESFORCE_IDS[1]] # Parent Account 2
parent_types = set(["Account"])
parent_types = set([ACCOUNT_OBJECT_TYPE])
affected_ids_by_type = defaultdict(set)
for parent_type, parent_id, _ in sf_db.get_changed_parent_ids_by_type(
updated_ids, parent_types
):
affected_ids_by_type[parent_type].add(parent_id)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[1] in affected_ids_by_type["Account"]
ACCOUNT_OBJECT_TYPE in affected_ids_by_type
), "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[1] in affected_ids_by_type[ACCOUNT_OBJECT_TYPE]
), "Direct parent ID not included"
# Test Case 2: Account with child in updated_ids
updated_ids = [_VALID_SALESFORCE_IDS[40]] # Child Contact
parent_types = set(["Account"])
parent_types = set([ACCOUNT_OBJECT_TYPE])
affected_ids_by_type = defaultdict(set)
for parent_type, parent_id, _ in sf_db.get_changed_parent_ids_by_type(
updated_ids, parent_types
):
affected_ids_by_type[parent_type].add(parent_id)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids_by_type["Account"]
ACCOUNT_OBJECT_TYPE in affected_ids_by_type
), "Account type not in affected_ids_by_type"
assert (
_VALID_SALESFORCE_IDS[0] in affected_ids_by_type[ACCOUNT_OBJECT_TYPE]
), "Parent of updated child not included"
# Test Case 3: Both direct and indirect affects
updated_ids = [_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]] # Both cases
parent_types = set(["Account"])
parent_types = set([ACCOUNT_OBJECT_TYPE])
affected_ids_by_type = defaultdict(set)
for parent_type, parent_id, _ in sf_db.get_changed_parent_ids_by_type(
updated_ids, parent_types
):
affected_ids_by_type[parent_type].add(parent_id)
assert "Account" in affected_ids_by_type, "Account type not in affected_ids_by_type"
affected_ids = affected_ids_by_type["Account"]
assert (
ACCOUNT_OBJECT_TYPE in affected_ids_by_type
), "Account type not in affected_ids_by_type"
affected_ids = affected_ids_by_type[ACCOUNT_OBJECT_TYPE]
assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
@@ -929,7 +940,7 @@ def _get_child_records_by_id_query(
object_id: str,
sf_type: str,
child_relationships: list[str],
relationships_to_fields: dict[str, list[str]],
relationships_to_fields: dict[str, set[str]],
) -> str:
"""Returns a SOQL query given the object id, type and child relationships.
@@ -963,7 +974,7 @@ def test_salesforce_connector_single() -> None:
# this record has some opportunity child records
parent_id = "001bm00000BXfhEAAT"
parent_type = "Account"
parent_type = ACCOUNT_OBJECT_TYPE
parent_types = [parent_type]
username = os.environ["SF_USERNAME"]
@@ -987,11 +998,11 @@ def test_salesforce_connector_single() -> None:
child_to_parent_types: dict[str, set[str]] = (
{}
) # reverse map from child to parent types
child_relationship_to_queryable_fields: dict[str, list[str]] = {}
child_relationship_to_queryable_fields: dict[str, set[str]] = {}
# parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {}
# Step 1 - make a list of all the types to download (parent + direct child + "User")
# Step 1 - make a list of all the types to download (parent + direct child + USER_OBJECT_TYPE)
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}")
for parent_type_working in parent_types:
child_types_working = sf_client.get_children_of_sf_type(parent_type_working)
@@ -1035,8 +1046,8 @@ def test_salesforce_connector_single() -> None:
result = sf_client.query(query)
records = result["records"]
record = records[0]
assert record["attributes"]["type"] == "Account"
parent_last_modified_date = record.get("LastModifiedDate", "")
assert record["attributes"]["type"] == ACCOUNT_OBJECT_TYPE
parent_last_modified_date = record.get(MODIFIED_FIELD, "")
parent_semantic_identifier = record.get("Name", "Unknown Object")
parent_last_modified_by_id = record.get("LastModifiedById")
@@ -1163,9 +1174,9 @@ def test_salesforce_connector_single() -> None:
# get user relationship if present
primary_owner_list = None
if parent_last_modified_by_id:
queryable_user_fields = sf_client.get_queryable_fields_by_type("User")
queryable_user_fields = sf_client.get_queryable_fields_by_type(USER_OBJECT_TYPE)
query = get_object_by_id_query(
parent_last_modified_by_id, "User", queryable_user_fields
parent_last_modified_by_id, USER_OBJECT_TYPE, queryable_user_fields
)
result = sf_client.query(query)
user_record = result["records"][0]

View File

@@ -4,7 +4,7 @@ dependencies:
version: 14.3.1
- name: vespa
repository: https://onyx-dot-app.github.io/vespa-helm-charts
version: 0.2.23
version: 0.2.24
- name: nginx
repository: oci://registry-1.docker.io/bitnamicharts
version: 15.14.0
@@ -14,5 +14,5 @@ dependencies:
- name: minio
repository: oci://registry-1.docker.io/bitnamicharts
version: 17.0.4
digest: sha256:4c938cf9138e4ff6f5ecac5c044324d508ef2b0e1a23ba3f2bc089015cb40ff6
generated: "2025-06-16T18:53:19.63168-07:00"
digest: sha256:dddd687525764f5698adc339a11d268b0ee9c3ca81f8d46c9e65a6bf2c21cf25
generated: "2025-08-06T19:00:41.218513-07:00"

View File

@@ -5,7 +5,7 @@ home: https://www.onyx.app/
sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.2.3
version: 0.2.5
appVersion: latest
annotations:
category: Productivity
@@ -23,7 +23,7 @@ dependencies:
repository: https://charts.bitnami.com/bitnami
condition: postgresql.enabled
- name: vespa
version: 0.2.23
version: 0.2.24
repository: https://onyx-dot-app.github.io/vespa-helm-charts
condition: vespa.enabled
- name: nginx

View File

@@ -7,6 +7,7 @@ import Title from "@/components/ui/title";
import { Button } from "@/components/ui/button";
import Link from "next/link";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import {
Tooltip,
TooltipContent,
@@ -136,6 +137,7 @@ function SourceTileTooltipWrapper({
export default function Page() {
const sources = useMemo(() => listSourceMetadata(), []);
const [searchTerm, setSearchTerm] = useState("");
const { data: federatedConnectors } = useFederatedConnectors();

View File

@@ -2834,7 +2834,7 @@ export function ChatPage({
currentSessionChatState == "input" &&
!loadingError &&
!submittedMessage && (
<div className="h-full w-[95%] mx-auto flex flex-col justify-center items-center">
<div className="h-full w-[95%] mx-auto flex flex-col justify-center items-center">
<ChatIntro selectedPersona={liveAssistant} />
{currentPersona && (

View File

@@ -37,6 +37,7 @@ import { transformLinkUri } from "@/lib/utils";
import FileInput from "@/app/admin/connectors/[connector]/pages/ConnectorInput/FileInput";
import { DatePicker } from "./ui/datePicker";
import { Textarea, TextareaProps } from "./ui/textarea";
import { RichTextSubtext } from "./RichTextSubtext";
export function SectionHeader({
children,
@@ -82,8 +83,25 @@ export function LabelWithTooltip({
}
export function SubLabel({ children }: { children: string | JSX.Element }) {
// Add whitespace-pre-wrap for multiline descriptions (when children is a string with newlines)
const hasNewlines = typeof children === "string" && children.includes("\n");
// If children is a string, use RichTextSubtext to parse and render links
if (typeof children === "string") {
return (
<div className="text-sm text-neutral-600 dark:text-neutral-300 mb-2">
<RichTextSubtext
text={children}
className={hasNewlines ? "whitespace-pre-wrap" : ""}
/>
</div>
);
}
return (
<div className="text-sm text-neutral-600 dark:text-neutral-300 mb-2">
<div
className={`text-sm text-neutral-600 dark:text-neutral-300 mb-2 ${hasNewlines ? "whitespace-pre-wrap" : ""}`}
>
{children}
</div>
);

View File

@@ -0,0 +1,85 @@
import React from "react";
interface RichTextSubtextProps {
text: string;
className?: string;
}
/**
* Component that renders text with clickable links.
* Detects URLs in the text and converts them to clickable links.
* Also supports markdown-style links like [text](url).
*/
export function RichTextSubtext({
text,
className = "",
}: RichTextSubtextProps) {
// Function to parse text and create React elements
const parseText = (input: string): React.ReactNode[] => {
const elements: React.ReactNode[] = [];
// Regex to match markdown links [text](url) and plain URLs
const combinedRegex = /(\[([^\]]+)\]\(([^)]+)\))|(https?:\/\/[^\s]+)/g;
let lastIndex = 0;
let match;
let key = 0;
while ((match = combinedRegex.exec(input)) !== null) {
// Add text before the match
if (match.index > lastIndex) {
elements.push(
<span key={`text-${key++}`}>
{input.slice(lastIndex, match.index)}
</span>
);
}
if (match[1]) {
// Markdown-style link [text](url)
const linkText = match[2];
const url = match[3];
elements.push(
<a
key={`link-${key++}`}
href={url}
target="_blank"
rel="noopener noreferrer"
className="text-link hover:text-link-hover underline"
onClick={(e) => e.stopPropagation()}
>
{linkText}
</a>
);
} else if (match[4]) {
// Plain URL
const url = match[4];
elements.push(
<a
key={`link-${key++}`}
href={url}
target="_blank"
rel="noopener noreferrer"
className="text-link hover:text-link-hover underline"
onClick={(e) => e.stopPropagation()}
>
{url}
</a>
);
}
lastIndex = match.index + match[0].length;
}
// Add remaining text after the last match
if (lastIndex < input.length) {
elements.push(
<span key={`text-${key++}`}>{input.slice(lastIndex)}</span>
);
}
return elements;
};
return <div className={className}>{parseText(text)}</div>;
}
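
The combined pattern alternates markdown-style links with bare URLs; because alternation tries the markdown branch first at each position, a URL inside a markdown link is never double-matched. Re-expressed in Python (the pattern syntax carries over), a quick sanity check:

import re

combined = re.compile(r"(\[([^\]]+)\]\(([^)]+)\))|(https?://[^\s]+)")
text = "See [our docs](https://docs.onyx.app) or https://onyx.app for details."
for m in combined.finditer(text):
    if m.group(1):
        print("markdown link:", m.group(2), "->", m.group(3))
    else:
        print("bare url:", m.group(4))
# markdown link: our docs -> https://docs.onyx.app
# bare url: https://onyx.app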

View File

@@ -15,7 +15,7 @@ export function StarterMessages({
<div
key={-4}
className={`
short:hidden
very-short:hidden
mx-auto
w-full
${

View File

@@ -599,14 +599,56 @@ export const connectorConfigs: Record<
description: "Configure Salesforce connector",
values: [
{
type: "list",
query: "Enter requested objects:",
label: "Requested Objects",
name: "requested_objects",
type: "tab",
name: "salesforce_config_type",
label: "Configuration Type",
optional: true,
description: `Specify the Salesforce object types you want us to index. If unsure, don't specify any objects and Onyx will default to indexing by 'Account'.
tabs: [
{
value: "simple",
label: "Simple",
fields: [
{
type: "list",
query: "Enter requested objects:",
label: "Requested Objects",
name: "requested_objects",
optional: true,
description: `Specify the Salesforce object types you want us to index. If unsure, don't specify any objects and Onyx will default to indexing by 'Account'.
Hint: Use the singular form of the object name (e.g., 'Opportunity' instead of 'Opportunities').`,
},
],
},
{
value: "advanced",
label: "Advanced",
fields: [
{
type: "text",
query: "Enter custom query config:",
label: "Custom Query Config",
name: "custom_query_config",
optional: true,
isTextArea: true,
description: `Enter a JSON configuration that precisely defines which fields and child objects to index. This gives you complete control over the data structure.
Example:
{
"Account": {
"fields": ["Id", "Name", "Industry"],
"associations": {
"Contact": ["Id", "FirstName", "LastName", "Email"]
}
}
}
[See our docs](https://docs.onyx.app/connectors/salesforce) for more details.`,
},
],
},
],
defaultTab: "simple",
},
],
advanced_values: [],