Compare commits

..

1 Commit

Author     SHA1        Message  Date
pablonyx   575cfa5224  fix      2025-03-27 11:16:58 -07:00
61 changed files with 474 additions and 846 deletions

View File

@@ -9,10 +9,6 @@ on:
- cron: "0 16 * * *"
env:
# AWS
AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS }}
AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS: ${{ secrets.AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS }}
# Confluence
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_TEST_SPACE: ${{ secrets.CONFLUENCE_TEST_SPACE }}

View File

@@ -93,12 +93,12 @@ def _get_access_for_documents(
)
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess.build(
user_emails=list(non_ee_access.user_emails),
user_groups=user_group_info.get(document_id, []),
access_map[document_id] = DocumentAccess(
user_emails=non_ee_access.user_emails,
user_groups=set(user_group_info.get(document_id, [])),
is_public=is_public_anywhere,
external_user_emails=list(ext_u_emails),
external_user_group_ids=list(ext_u_groups),
external_user_emails=ext_u_emails,
external_user_group_ids=ext_u_groups,
)
return access_map
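
Note on the DocumentAccess change above: the revert swaps calls to the `.build(...)` classmethod (which accepted lists and normalized them) for direct construction of the frozen dataclass with pre-built sets. A minimal sketch of the two styles, with simplified fields and no prefixing logic (both simplifications are assumptions, not the real class):

from dataclasses import dataclass

@dataclass(frozen=True)
class SimpleAccess:
    user_emails: set[str | None]
    user_groups: set[str]
    is_public: bool

    @classmethod
    def build(
        cls, user_emails: list[str | None], user_groups: list[str], is_public: bool
    ) -> "SimpleAccess":
        # Normalize possibly-duplicated list input into sets in one place
        return cls(
            user_emails={e for e in user_emails if e},
            user_groups=set(user_groups),
            is_public=is_public,
        )

# Direct construction, as in the post-revert code, makes callers build the sets:
access = SimpleAccess(user_emails={"a@x.com"}, user_groups=set(), is_public=False)

Sets rather than lists also make deduplication of ACL entries and membership checks automatic.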

View File

@@ -2,6 +2,7 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
@@ -31,6 +32,8 @@ def gather_stream_for_answer_api(
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
elif isinstance(packet, OnyxContexts):
response.contexts = packet
if answer:
response.answer = answer

View File

@@ -25,10 +25,6 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/onyx/configs/saml_co
#####
# Auto Permission Sync
#####
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
@@ -43,7 +39,6 @@ CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
@@ -77,13 +72,6 @@ OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
"OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# The posthog client does not accept empty API keys or hosts; however, it fails silently
# when the capture is called. These defaults prevent Posthog issues from breaking the Onyx app
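
Note: the config entries in this file share one idiom worth calling out: `int(os.environ.get(NAME) or default)` treats an empty environment variable the same as an unset one, because `or` falls through on any falsy value. A small illustration (the variable name is hypothetical):

import os

# Falls back when the variable is missing OR set to "", unlike
# os.environ.get("X", default), which only covers the missing case.
EXAMPLE_SYNC_FREQUENCY = int(os.environ.get("EXAMPLE_SYNC_FREQUENCY") or 5 * 60)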

View File

@@ -3,8 +3,6 @@ from collections.abc import Generator
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY
from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync
from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync
@@ -68,13 +66,13 @@ GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC: set[DocumentSource] = {
DOC_PERMISSION_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all doc permissions every 5 minutes
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: SLACK_PERMISSION_DOC_SYNC_FREQUENCY,
DocumentSource.SLACK: 5 * 60,
}
# If nothing is specified here, we run the doc_sync every time the celery beat runs
EXTERNAL_GROUP_SYNC_PERIODS: dict[DocumentSource, int] = {
# Polling is not supported so we fetch all group permissions every 30 minutes
DocumentSource.GOOGLE_DRIVE: GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY,
DocumentSource.GOOGLE_DRIVE: 5 * 60,
DocumentSource.CONFLUENCE: CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY,
}

View File

@@ -14,6 +14,7 @@ from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
@@ -55,6 +56,25 @@ logger = setup_logger()
router = APIRouter(prefix="/chat")
def _translate_doc_response_to_simple_doc(
doc_response: QADocsResponse,
) -> list[SimpleDoc]:
return [
SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier,
link=doc.link,
blurb=doc.blurb,
match_highlights=[
highlight for highlight in doc.match_highlights if highlight
],
source_type=doc.source_type,
metadata=doc.metadata,
)
for doc in doc_response.top_documents
]
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,
@@ -91,6 +111,9 @@ def _convert_packet_stream_to_response(
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)

View File

@@ -8,6 +8,7 @@ from pydantic import model_validator
from ee.onyx.server.manage.models import StandardAnswer
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
@@ -163,6 +164,8 @@ class ChatBasicResponse(BaseModel):
cited_documents: dict[int, str] | None = None
# FOR BACKWARDS COMPATIBILITY
# TODO: deprecate both of these
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
@@ -217,3 +220,4 @@ class OneShotQAResponse(BaseModel):
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
contexts: OnyxContexts | None = None

View File

@@ -505,8 +505,11 @@ async def setup_tenant(tenant_id: str) -> None:
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Run Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
# Run Alembic migrations in a way that isolates it from the current event loop
# Create a new event loop for this synchronous operation
loop = asyncio.get_event_loop()
# Use run_in_executor which properly isolates the thread execution
await loop.run_in_executor(None, lambda: run_alembic_migrations(tenant_id))
# Configure the tenant with default settings
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
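
Note: this hunk replaces `asyncio.to_thread` with `loop.run_in_executor`. Both run a blocking callable (here, Alembic migrations) off the event loop thread; `to_thread` is essentially a wrapper over `run_in_executor(None, ...)` that additionally propagates contextvars. A runnable sketch of the two forms, with a stand-in for the migration function:

import asyncio
import time

def run_migrations(tenant_id: str) -> None:
    time.sleep(1)  # stand-in for a blocking Alembic upgrade

async def main() -> None:
    # Form 1: copies the current contextvars context into the worker thread
    await asyncio.to_thread(run_migrations, "tenant_a")
    # Form 2: submits directly to the default thread pool executor
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(None, lambda: run_migrations("tenant_b"))

asyncio.run(main())

Whether the executor form isolates the work any further than to_thread, as the new comment suggests, is not obvious from the snippet alone.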

View File

@@ -18,7 +18,7 @@ def _get_access_for_document(
document_id=document_id,
)
doc_access = DocumentAccess.build(
return DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
@@ -26,8 +26,6 @@ def _get_access_for_document(
is_public=info[2] if info else False,
)
return doc_access
def get_access_for_document(
document_id: str,
@@ -40,12 +38,12 @@ def get_access_for_document(
def get_null_document_access() -> DocumentAccess:
return DocumentAccess.build(
user_emails=[],
user_groups=[],
return DocumentAccess(
user_emails=set(),
user_groups=set(),
is_public=False,
external_user_emails=[],
external_user_group_ids=[],
external_user_emails=set(),
external_user_group_ids=set(),
)
@@ -58,18 +56,18 @@ def _get_access_for_documents(
document_ids=document_ids,
)
doc_access = {
document_id: DocumentAccess.build(
user_emails=[email for email in user_emails if email],
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
# MIT version will wipe all groups and external groups on update
user_groups=[],
user_groups=set(),
is_public=is_public,
external_user_emails=[],
external_user_group_ids=[],
external_user_emails=set(),
external_user_group_ids=set(),
)
for document_id, user_emails, is_public in document_access_info
}
-# Sometimes the document has not been indexed by the indexing job yet, in those cases
+# Sometimes the document has not be indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.

View File

@@ -56,45 +56,33 @@ class DocExternalAccess:
)
@dataclass(frozen=True, init=False)
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
external_user_emails: set[str]
external_user_group_ids: set[str]
is_public: bool
def __init__(self) -> None:
raise TypeError(
"Use `DocumentAccess.build(...)` instead of creating an instance directly."
)
def to_acl(self) -> set[str]:
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:
acl_set.add(prefix_user_email(user_email))
for group_name in self.user_groups:
acl_set.add(prefix_user_group(group_name))
for external_user_email in self.external_user_emails:
acl_set.add(prefix_user_email(external_user_email))
for external_group_id in self.external_user_group_ids:
acl_set.add(prefix_external_group(external_group_id))
if self.is_public:
acl_set.add(PUBLIC_DOC_PAT)
return acl_set
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(
@@ -105,32 +93,29 @@ class DocumentAccess(ExternalAccess):
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
"""Don't prefix incoming data wth acl type, prefix on read from to_acl!"""
obj = object.__new__(cls)
object.__setattr__(
obj, "user_emails", {user_email for user_email in user_emails if user_email}
return cls(
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
)
object.__setattr__(obj, "user_groups", set(user_groups))
object.__setattr__(
obj,
"external_user_emails",
{external_email for external_email in external_user_emails},
)
object.__setattr__(
obj,
"external_user_group_ids",
{external_group_id for external_group_id in external_user_group_ids},
)
object.__setattr__(obj, "is_public", is_public)
return obj
default_public_access = DocumentAccess.build(
external_user_emails=[],
external_user_group_ids=[],
user_emails=[],
user_groups=[],
default_public_access = DocumentAccess(
external_user_emails=set(),
external_user_group_ids=set(),
user_emails=set(),
user_groups=set(),
is_public=True,
)
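
Note: the to_acl rewrite in this file is behavior-preserving; both versions emit one prefixed token per principal plus an optional public marker. A self-contained sketch with hypothetical prefix helpers (the real ones live in onyx's access utilities):

PUBLIC_DOC_PAT = "PUBLIC"  # assumed token, not necessarily the real constant

def prefix_user_email(email: str) -> str:
    return f"user_email:{email}"

def prefix_user_group(group: str) -> str:
    return f"group:{group}"

def prefix_external_group(group_id: str) -> str:
    return f"external_group:{group_id}"

def to_acl(user_emails, user_groups, external_group_ids, is_public) -> set[str]:
    acl = {prefix_user_email(e) for e in user_emails if e}
    acl |= {prefix_user_group(g) for g in user_groups}
    acl |= {prefix_external_group(g) for g in external_group_ids}
    if is_public:
        acl.add(PUBLIC_DOC_PAT)
    return acl

print(to_acl({"a@x.com", None}, {"eng"}, {"drive_group_1"}, True))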

View File

@@ -7,6 +7,7 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
@@ -23,7 +24,7 @@ def process_llm_stream(
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")

View File

@@ -156,6 +156,7 @@ def generate_initial_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -183,6 +183,7 @@ def generate_validate_refined_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -57,6 +57,7 @@ def format_results(
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -13,7 +13,9 @@ from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -57,7 +59,9 @@ def basic_use_tool_response(
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(section_to_llm_doc(section))
initial_search_results.append(
context_from_inference_section(section)
)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -1,5 +1,6 @@
from datetime import timedelta
from typing import Any
from typing import cast
from celery import Celery
from celery import signals
@@ -9,10 +10,12 @@ from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import SqlEngine
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST
from shared_configs.configs import MULTI_TENANT
@@ -138,6 +141,8 @@ class DynamicTenantScheduler(PersistentScheduler):
"""Only updates the actual beat schedule on the celery app when it changes"""
do_update = False
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
task_logger.debug("_try_updating_schedule starting")
tenant_ids = get_all_tenant_ids()
@@ -147,7 +152,16 @@ class DynamicTenantScheduler(PersistentScheduler):
current_schedule = self.schedule.items()
# get potential new state
beat_multiplier = OnyxRuntime.get_beat_multiplier()
beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
task_logger.error(
f"Invalid beat_multiplier value: {beat_multiplier_raw}"
)
new_schedule = self._generate_schedule(tenant_ids, beat_multiplier)
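
Note: the scheduler now inlines what OnyxRuntime.get_beat_multiplier() did before (see the OnyxRuntime diff further down): read a raw value from a Redis replica and parse it with a safe fallback. A minimal sketch of that parse-with-default pattern (key name taken from the diff; client setup is assumed):

import redis

CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0

def read_beat_multiplier(r: redis.Redis) -> float:
    raw = r.get("runtime:beat_multiplier")
    if raw is None:
        return CLOUD_BEAT_MULTIPLIER_DEFAULT
    try:
        # redis-py returns bytes unless the client was built with decode_responses=True
        value = raw.decode() if isinstance(raw, bytes) else raw
        return float(value)
    except ValueError:
        return CLOUD_BEAT_MULTIPLIER_DEFAULT

One difference from the removed helper: the old get_beat_multiplier() also clamped non-positive values back to 1.0, which the inlined version in the hunk does not.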

View File

@@ -21,7 +21,6 @@ BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds)
# we have a better implementation (backpressure, etc)
# Note that DynamicTenantScheduler can adjust the runtime value for this via Redis
CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0
# tasks that run in either self-hosted or cloud
beat_task_templates: list[dict] = []

View File

@@ -389,8 +389,6 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
credential_id_to_delete: int | None = None
connector_id_to_delete: int | None = None
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
@@ -445,35 +443,26 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)
# Store IDs before potentially expiring cc_pair
connector_id_to_delete = cc_pair.connector_id
credential_id_to_delete = cc_pair.credential_id
# Explicitly delete document by connector credential pair records before deleting the connector
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
delete_all_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# Flush to ensure document deletion happens before connector deletion
db_session.flush()
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=connector_id_to_delete,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
@@ -506,15 +495,15 @@ def monitor_connector_deletion_taskset(
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={connector_id_to_delete} "
f"credential={credential_id_to_delete} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={fence_data.num_tasks}"
)
@@ -564,7 +553,7 @@ def validate_connector_deletion_fences(
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_upsert_tasks: set[str],
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
@@ -651,7 +640,7 @@ def validate_connector_deletion_fence(
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_upsert_tasks:
if member_str in queued_tasks:
continue
tasks_not_in_celery += 1
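
Note: the deletion hunk removes both the stored-ID workaround and the db_session.expire(cc_pair) call, going back to reading cc_pair.connector_id directly. The workaround existed because attribute access on an expired instance triggers a refresh query, which raises ObjectDeletedError once the underlying row is gone. A runnable sketch of that gotcha (toy model, in-memory SQLite):

from sqlalchemy import Integer, create_engine, text
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class CCPair(Base):
    __tablename__ = "cc_pair"
    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    connector_id: Mapped[int] = mapped_column(Integer)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    pair = CCPair(id=1, connector_id=7)
    session.add(pair)
    session.commit()

    connector_id = pair.connector_id  # copy the plain value first
    session.expire(pair)              # next attribute access re-queries the DB
    session.execute(text("DELETE FROM cc_pair"))
    # pair.connector_id here would raise ObjectDeletedError during refresh;
    # the copied connector_id remains safe to log.
    print(connector_id)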

View File

@@ -17,7 +17,6 @@ from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.document import upsert_document_external_perms
from ee.onyx.external_permissions.sync_params import DOC_PERMISSION_SYNC_PERIODS
@@ -64,7 +63,6 @@ from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyn
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import format_error_for_logging
@@ -108,10 +106,9 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
source_sync_period = DOC_PERMISSION_SYNC_PERIODS.get(cc_pair.connector.source)
# If RESTRICTED_FETCH_PERIOD[source] is None, we always run the sync.
if not source_sync_period:
source_sync_period = DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY
source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier())
return True
# If the last sync is greater than the full fetch period, we run the sync
next_sync = last_perm_sync + timedelta(seconds=source_sync_period)
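
Note: the due-check reads as "run the sync once last_perm_sync plus the period has elapsed". A compact rendering of that predicate (datetimes are illustrative):

from datetime import datetime, timedelta, timezone

def sync_is_due(last_sync: datetime, period_seconds: int) -> bool:
    next_sync = last_sync + timedelta(seconds=period_seconds)
    return datetime.now(timezone.utc) >= next_sync

ten_minutes_ago = datetime.now(timezone.utc) - timedelta(minutes=10)
print(sync_is_due(ten_minutes_ago, 5 * 60))  # True: the 5-minute period has passed

Also note that with the DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY fallback deleted, a source with no entry in DOC_PERMISSION_SYNC_PERIODS is now always considered due.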

View File

@@ -72,7 +72,6 @@ from onyx.redis.redis_pool import get_redis_replica_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.redis.redis_utils import is_fence
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
@@ -402,11 +401,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
logger.warning(f"Adding {key_bytes} to the lookup table.")
redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes)
redis_client.set(
OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE,
1,
ex=OnyxRuntime.get_build_fence_lookup_table_interval(),
)
redis_client.set(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=300)
# 1/3: KICKOFF

View File

@@ -194,6 +194,17 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class OnyxContext(BaseModel):
content: str
document_id: str
semantic_identifier: str
blurb: str
class OnyxContexts(BaseModel):
contexts: list[OnyxContext]
class OnyxAnswer(BaseModel):
answer: str | None
@@ -259,6 +270,7 @@ class PersonaOverrideConfig(BaseModel):
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
| OnyxContexts
| FileChatDisplay
| CustomToolResponse
| StreamingError

View File

@@ -29,6 +29,7 @@ from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
@@ -130,6 +131,7 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -298,6 +300,7 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| OnyxContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -916,6 +919,8 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(OnyxContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:

View File

@@ -301,10 +301,6 @@ def prune_sections(
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
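
Note: the removed assert guarded _merge_doc_chunks against receiving chunks from more than one document. A standalone sketch of that invariant plus the sort step (the Chunk type is a stand-in for InferenceChunk):

from dataclasses import dataclass

@dataclass
class Chunk:
    document_id: str
    chunk_id: int
    content: str

def merge_doc_chunks(chunks: list[Chunk]) -> str:
    assert len({c.document_id for c in chunks}) == 1, (
        "One distinct document must be passed into merge_doc_chunks"
    )
    # Assuming no duplicates by this point, order chunks by position in the doc
    ordered = sorted(chunks, key=lambda c: c.chunk_id)
    return "\n".join(c.content for c in ordered)

print(merge_doc_chunks([Chunk("d1", 2, "world"), Chunk("d1", 1, "hello")]))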

View File

@@ -3,6 +3,7 @@ from collections.abc import Sequence
from pydantic import BaseModel
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceChunk
@@ -11,7 +12,7 @@ class DocumentIdOrderMapping(BaseModel):
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
chunks: Sequence[InferenceChunk | LlmDoc | OnyxContext], one_indexed: bool = True
) -> DocumentIdOrderMapping:
order_mapping = {}
current = 1 if one_indexed else 0
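
Note: the hunk cuts off before the loop body, but the visible setup (an empty mapping plus a counter starting at 1 or 0) implies first-appearance ordering. A plausible reconstruction of the whole function, which is an assumption beyond what the diff shows:

from types import SimpleNamespace

def map_document_id_order(items, one_indexed: bool = True) -> dict[str, int]:
    order_mapping: dict[str, int] = {}
    current = 1 if one_indexed else 0
    for item in items:
        if item.document_id not in order_mapping:
            order_mapping[item.document_id] = current
            current += 1
    return order_mapping

docs = [SimpleNamespace(document_id=d) for d in ["a", "b", "a"]]
print(map_document_id_order(docs))  # {'a': 1, 'b': 2}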

View File

@@ -382,7 +382,6 @@ ONYX_CLOUD_TENANT_ID = "cloud"
# the redis namespace for runtime variables
ONYX_CLOUD_REDIS_RUNTIME = "runtime"
CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT = 600
class OnyxCeleryTask:

View File

@@ -87,7 +87,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
credentials.get(key)
for key in ["aws_access_key_id", "aws_secret_access_key"]
):
raise ConnectorMissingCredentialError("Amazon S3")
raise ConnectorMissingCredentialError("Google Cloud Storage")
session = boto3.Session(
aws_access_key_id=credentials["aws_access_key_id"],

View File

@@ -28,9 +28,8 @@ from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import read_text_file
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -70,9 +69,7 @@ def _process_egnyte_file(
file_name = file_metadata["name"]
extension = get_file_ext(file_name)
if not is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
if not is_valid_file_ext(extension):
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return None

View File

@@ -22,9 +22,8 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_accepted_file_ext
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import OnyxExtensionType
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger
@@ -52,7 +51,7 @@ def _read_files_and_metadata(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), subfile, metadata
elif is_accepted_file_ext(extension, OnyxExtensionType.All):
elif is_valid_file_ext(extension):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -123,7 +122,7 @@ def _process_file(
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
return []
if not is_accepted_file_ext(extension, OnyxExtensionType.All):
if not is_valid_file_ext(extension):
logger.warning(
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
)

View File

@@ -28,9 +28,7 @@ from onyx.connectors.google_drive.doc_conversion import (
)
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import DriveRetrievalStage
@@ -88,18 +86,13 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
size_threshold: int,
retriever_email: str,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
# We used to always get the user email from the file owners when available,
# but this was causing issues with shared folders where the owner was not included in the service account
# now we use the email of the account that successfully listed the file. Leaving this in case we end up
# wanting to retry with file owners and/or admin email at some point.
# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = retriever_email
# Only construct these services when needed
user_drive_service = lazy_eval(
lambda: get_drive_service(creds, user_email=user_email)
@@ -457,11 +450,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.info(f"Getting all files in my drive as '{user_email}'")
yield from add_retrieval_info(
get_all_files_in_my_drive_and_shared(
get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
include_shared_with_me=self.include_files_shared_with_me,
start=curr_stage.completed_until if resuming else start,
end=end,
),
@@ -924,28 +916,20 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[RetrievedDriveFile] = []
files_batch: list[GoogleDriveFileType] = []
def _yield_batch(
files_batch: list[RetrievedDriveFile],
files_batch: list[GoogleDriveFileType],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [
(
convert_func,
(
file.user_email,
file.drive_file,
),
)
for file in files_batch
]
func_with_args = [(convert_func, (file,)) for file in files_batch]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
@@ -983,7 +967,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)
continue
files_batch.append(retrieved_file)
files_batch.append(retrieved_file.drive_file)
if len(files_batch) < self.batch_size:
continue
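
Note: the surrounding connector code batches retrieved files and fans each batch out to worker threads. A simplified sketch of that batch-and-convert shape using only the standard library (run_functions_tuples_in_parallel is onyx's own helper; ThreadPoolExecutor stands in for it here):

from concurrent.futures import ThreadPoolExecutor
from functools import partial

def convert(creds: str, file: dict) -> str:
    return f"{creds}:{file['id']}"  # stand-in for Document conversion

def yield_converted_batches(files, batch_size: int, creds: str):
    convert_func = partial(convert, creds)
    batch: list[dict] = []
    for f in files:
        batch.append(f)
        if len(batch) < batch_size:
            continue
        with ThreadPoolExecutor(max_workers=8) as pool:
            yield from pool.map(convert_func, batch)
        batch = []
    if batch:  # flush the final partial batch
        with ThreadPoolExecutor(max_workers=8) as pool:
            yield from pool.map(convert_func, batch)

files = [{"id": "a"}, {"id": "b"}, {"id": "c"}]
print(list(yield_converted_batches(files, batch_size=2, creds="creds")))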

View File

@@ -87,17 +87,35 @@ def _download_and_extract_sections_basic(
mime_type = file["mimeType"]
link = file.get("webViewLink", "")
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
@@ -106,100 +124,88 @@ def _download_and_extract_sections_basic(
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
logger.warning(f"Failed to download {file_name}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []
def convert_drive_item_to_document(
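
Note: restructured, the long if/elif ladder above is a dispatch on MIME type: export Google-native types as plain text, then branch per Office/PDF/image type over the downloaded bytes. A condensed sketch of the dispatch shape (handler names are placeholders, not the real onyx functions):

def handle_plain(data: bytes) -> str:
    return data.decode("utf-8")

HANDLERS = {
    "text/plain": handle_plain,
    # "application/pdf": handle_pdf, etc., one handler per supported type
}

def extract_text(mime_type: str, data: bytes) -> str | None:
    handler = HANDLERS.get(mime_type)
    if handler is None:
        return None  # caller falls back to generic extraction
    return handler(data)

print(extract_text("text/plain", b"hello"))

A table-driven version like this keeps each format's quirks in its own handler, so the try/except wrapper added in the hunk only has to exist in one place.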

View File

@@ -214,11 +214,10 @@ def get_files_in_shared_drive(
yield file
def get_all_files_in_my_drive_and_shared(
def get_all_files_in_my_drive(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool,
include_shared_with_me: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
@@ -230,8 +229,7 @@ def get_all_files_in_my_drive_and_shared(
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -248,8 +246,7 @@ def get_all_files_in_my_drive_and_shared(
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,
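
Note: this hunk drops the include_shared_with_me flag and unconditionally appends the 'me' in owners clause. Drive queries are plain strings, so the pre-revert flag handling was simple conditional concatenation; a sketch of that version:

DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"

def build_file_query(include_shared_with_me: bool) -> str:
    query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
    query += " and trashed = false"
    if not include_shared_with_me:
        query += " and 'me' in owners"
    return query

print(build_file_query(include_shared_with_me=False))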

View File

@@ -20,8 +20,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ALL_ACCEPTED_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import VALID_FILE_EXTENSIONS
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -298,7 +298,7 @@ class HighspotConnector(LoadConnector, PollConnector, SlimConnector):
elif (
is_valid_format
and file_extension in ALL_ACCEPTED_FILE_EXTENSIONS
and file_extension in VALID_FILE_EXTENSIONS
and can_download
):
# For documents, try to get the text content

View File

@@ -339,12 +339,6 @@ class SearchPipeline:
self._retrieved_sections = self._get_sections()
return self._retrieved_sections
@property
def merged_retrieved_sections(self) -> list[InferenceSection]:
"""Should be used to display in the UI in order to prevent displaying
multiple sections for the same document as separate "documents"."""
return _merge_sections(sections=self.retrieved_sections)
@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -421,10 +415,6 @@ class SearchPipeline:
raise ValueError(
"Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
)
# NOTE: final_context_sections must be accessed before accessing self._postprocessing_generator
# since the property sets the generator. DO NOT REMOVE.
_ = self.final_context_sections
self._section_relevance = next(
cast(
Iterator[list[SectionRelevancePiece]],

View File

@@ -7,8 +7,6 @@ from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Sequence
from email.parser import Parser as EmailParser
from enum import auto
from enum import IntFlag
from io import BytesIO
from pathlib import Path
from typing import Any
@@ -37,7 +35,7 @@ logger = setup_logger()
TEXT_SECTION_SEPARATOR = "\n\n"
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
PLAIN_TEXT_FILE_EXTENSIONS = [
".txt",
".md",
".mdx",
@@ -51,7 +49,7 @@ ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS = [
".yaml",
]
ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
VALID_FILE_EXTENSIONS = PLAIN_TEXT_FILE_EXTENSIONS + [
".pdf",
".docx",
".pptx",
@@ -59,21 +57,12 @@ ACCEPTED_DOCUMENT_FILE_EXTENSIONS = [
".eml",
".epub",
".html",
]
ACCEPTED_IMAGE_FILE_EXTENSIONS = [
".png",
".jpg",
".jpeg",
".webp",
]
ALL_ACCEPTED_FILE_EXTENSIONS = (
ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
+ ACCEPTED_DOCUMENT_FILE_EXTENSIONS
+ ACCEPTED_IMAGE_FILE_EXTENSIONS
)
IMAGE_MEDIA_TYPES = [
"image/png",
"image/jpeg",
@@ -81,15 +70,8 @@ IMAGE_MEDIA_TYPES = [
]
class OnyxExtensionType(IntFlag):
Plain = auto()
Document = auto()
Multimedia = auto()
All = Plain | Document | Multimedia
def is_text_file_extension(file_name: str) -> bool:
return any(file_name.endswith(ext) for ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS)
return any(file_name.endswith(ext) for ext in PLAIN_TEXT_FILE_EXTENSIONS)
def get_file_ext(file_path_or_name: str | Path) -> str:
@@ -101,20 +83,8 @@ def is_valid_media_type(media_type: str) -> bool:
return media_type in IMAGE_MEDIA_TYPES
def is_accepted_file_ext(ext: str, ext_type: OnyxExtensionType) -> bool:
if ext_type & OnyxExtensionType.Plain:
if ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Document:
if ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
return True
if ext_type & OnyxExtensionType.Multimedia:
if ext in ACCEPTED_IMAGE_FILE_EXTENSIONS:
return True
return False
def is_valid_file_ext(ext: str) -> bool:
return ext in VALID_FILE_EXTENSIONS
def is_text_file(file: IO[bytes]) -> bool:
@@ -412,9 +382,6 @@ def extract_file_text(
"""
Legacy function that returns *only text*, ignoring embedded images.
For backward-compatibility in code that only wants text.
NOTE: Ignoring seems to be defined as returning an empty string for files it can't
handle (such as images).
"""
extension_to_function: dict[str, Callable[[IO[Any]], str]] = {
".pdf": pdf_to_text,
@@ -438,9 +405,7 @@ def extract_file_text(
if extension is None:
extension = get_file_ext(file_name)
if is_accepted_file_ext(
extension, OnyxExtensionType.Plain | OnyxExtensionType.Document
):
if is_valid_file_ext(extension):
func = extension_to_function.get(extension, file_io_to_text)
file.seek(0)
return func(file)
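
Note: the pre-revert code modeled extension categories as an IntFlag so call sites could request any union (e.g. Plain | Document) with one check; the revert returns to a single flat list. A self-contained sketch of the IntFlag approach, with abbreviated extension sets:

from enum import IntFlag, auto

class ExtType(IntFlag):
    Plain = auto()
    Document = auto()
    Multimedia = auto()
    All = Plain | Document | Multimedia

PLAIN = {".txt", ".md"}
DOCUMENT = {".pdf", ".docx"}
MULTIMEDIA = {".png", ".jpg"}

def is_accepted(ext: str, ext_type: ExtType) -> bool:
    if ext_type & ExtType.Plain and ext in PLAIN:
        return True
    if ext_type & ExtType.Document and ext in DOCUMENT:
        return True
    if ext_type & ExtType.Multimedia and ext in MULTIMEDIA:
        return True
    return False

print(is_accepted(".pdf", ExtType.Plain | ExtType.Document))  # True
print(is_accepted(".png", ExtType.Plain | ExtType.Document))  # False

One practical consequence visible in the file connector hunks above: is_valid_file_ext has no image bucket, so call sites that previously passed OnyxExtensionType.All now skip image files.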

View File

@@ -224,6 +224,27 @@ class Chunker:
)
chunks_list.append(new_chunk)
def _chunk_document(
self,
document: IndexingDocument,
title_prefix: str,
metadata_suffix_semantic: str,
metadata_suffix_keyword: str,
content_token_limit: int,
) -> list[DocAwareChunk]:
"""
Legacy method for backward compatibility.
Calls _chunk_document_with_sections with document.sections.
"""
return self._chunk_document_with_sections(
document,
document.processed_sections,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
content_token_limit,
)
def _chunk_document_with_sections(
self,
document: IndexingDocument,
@@ -243,7 +264,7 @@ class Chunker:
for section_idx, section in enumerate(sections):
# Get section text and other attributes
section_text = clean_text(str(section.text or ""))
section_text = clean_text(section.text or "")
section_link_text = section.link or ""
image_url = section.image_file_name

View File

@@ -439,7 +439,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
**document.dict(),
processed_sections=[
Section(
text=section.text if isinstance(section, TextSection) else "",
text=section.text if isinstance(section, TextSection) else None,
link=section.link,
image_file_name=section.image_file_name
if isinstance(section, ImageSection)
@@ -459,11 +459,11 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
for section in document.sections:
# For ImageSection, process and create base Section with both text and image_file_name
if isinstance(section, ImageSection):
# Default section with image path preserved - ensure text is always a string
# Default section with image path preserved
processed_section = Section(
link=section.link,
image_file_name=section.image_file_name,
text="", # Initialize with empty string
text=None, # Will be populated if summarization succeeds
)
# Try to get image summary
@@ -506,21 +506,13 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
# For TextSection, create a base Section with text and link
elif isinstance(section, TextSection):
processed_section = Section(
text=section.text or "", # Ensure text is always a string, not None
link=section.link,
image_file_name=None,
text=section.text, link=section.link, image_file_name=None
)
processed_sections.append(processed_section)
# If it's already a base Section (unlikely), just append it with text validation
# If it's already a base Section (unlikely), just append it
else:
# Ensure text is always a string
processed_section = Section(
text=section.text if section.text is not None else "",
link=section.link,
image_file_name=section.image_file_name,
)
processed_sections.append(processed_section)
processed_sections.append(section)
# Create IndexingDocument with original sections and processed_sections
indexed_document = IndexingDocument(

View File

@@ -1,19 +1,10 @@
import io
from typing import cast
from PIL import Image
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.background.celery.tasks.beat_schedule import (
CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT,
)
from onyx.configs.constants import CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
from onyx.configs.constants import ONYX_CLOUD_REDIS_RUNTIME
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import ONYX_EMAILABLE_LOGO_MAX_DIM
from onyx.db.engine import get_session_with_shared_schema
from onyx.file_store.file_store import PostgresBackedFileStore
from onyx.redis.redis_pool import get_redis_replica_client
from onyx.utils.file import FileWithMimeType
from onyx.utils.file import OnyxStaticFileManager
from onyx.utils.variable_functionality import (
@@ -96,72 +87,3 @@ class OnyxRuntime:
)
return OnyxRuntime._get_with_static_fallback(db_filename, STATIC_FILENAME)
@staticmethod
def get_beat_multiplier() -> float:
"""the beat multiplier is used to scale up or down the frequency of certain beat
tasks in the cloud. It has a significant effect on load and is useful to adjust
in real time."""
beat_multiplier: float = CLOUD_BEAT_MULTIPLIER_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
beat_multiplier_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:beat_multiplier")
if beat_multiplier_raw is not None:
try:
beat_multiplier_bytes = cast(bytes, beat_multiplier_raw)
beat_multiplier = float(beat_multiplier_bytes.decode())
except ValueError:
pass
if beat_multiplier <= 0.0:
return 1.0
return beat_multiplier
@staticmethod
def get_doc_permission_sync_multiplier() -> float:
"""Permission syncs are a significant source of load / queueing in the cloud."""
value: float = CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
value_raw = r.get(f"{ONYX_CLOUD_REDIS_RUNTIME}:doc_permission_sync_multiplier")
if value_raw is not None:
try:
value_bytes = cast(bytes, value_raw)
value = float(value_bytes.decode())
except ValueError:
pass
if value <= 0.0:
return 1.0
return value
@staticmethod
def get_build_fence_lookup_table_interval() -> int:
"""We maintain an active fence table to make lookups of existing fences efficient.
However, reconstructing the table is expensive, so adjusting it in realtime is useful.
"""
interval: int = CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
r = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
interval_raw = r.get(
f"{ONYX_CLOUD_REDIS_RUNTIME}:build_fence_lookup_table_interval"
)
if interval_raw is not None:
try:
interval_bytes = cast(bytes, interval_raw)
interval = int(interval_bytes.decode())
except ValueError:
pass
if interval <= 0.0:
return CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT
return interval

View File

@@ -12,6 +12,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
@@ -41,6 +42,9 @@ from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from onyx.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
@@ -54,6 +58,7 @@ from onyx.utils.special_types import JSON_ro
logger = setup_logger()
SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"
@@ -352,13 +357,13 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
)
yield from yield_search_responses(
query=query,
# give back the merged sections to prevent duplicate docs from appearing in the UI
get_retrieved_sections=lambda: search_pipeline.merged_retrieved_sections,
get_final_context_sections=lambda: search_pipeline.final_context_sections,
search_query_info=search_query_info,
get_section_relevance=lambda: search_pipeline.section_relevance,
search_tool=self,
query,
lambda: search_pipeline.retrieved_sections,
lambda: search_pipeline.reranked_sections,
lambda: search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
)
def final_result(self, *args: ToolResponse) -> JSON_ro:
@@ -400,6 +405,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
def yield_search_responses(
query: str,
get_retrieved_sections: Callable[[], list[InferenceSection]],
get_reranked_sections: Callable[[], list[InferenceSection]],
get_final_context_sections: Callable[[], list[InferenceSection]],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
@@ -417,6 +423,16 @@ def yield_search_responses(
),
)
yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
context_from_inference_section(section)
for section in get_reranked_sections()
]
),
)
section_relevance = get_section_relevance()
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,
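
Note the call-site change in SearchTool above: the pre-revert code passed yield_search_responses its arguments as keywords, while the revert passes them positionally as the signature also gains get_reranked_sections. With several parameters of the same Callable type, keywords are the safer style; a toy illustration (not the real signature):

from typing import Callable

def yield_responses(
    query: str,
    get_retrieved: Callable[[], list[str]],
    get_reranked: Callable[[], list[str]],
    get_final: Callable[[], list[str]],
) -> list[str]:
    return get_final()

# Positionally, swapping two lambdas is a silent bug; keywords make it impossible:
docs = yield_responses(
    query="q",
    get_retrieved=lambda: ["retrieved"],
    get_reranked=lambda: ["reranked"],
    get_final=lambda: ["final"],
)
print(docs)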

View File

@@ -1,4 +1,5 @@
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceSection
from onyx.prompts.prompt_utils import clean_up_source
@@ -31,23 +32,10 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
return doc_dict
def section_to_llm_doc(section: InferenceSection) -> LlmDoc:
possible_link_chunks = [section.center_chunk] + section.chunks
link: str | None = None
for chunk in possible_link_chunks:
if chunk.source_links:
link = list(chunk.source_links.values())[0]
break
return LlmDoc(
document_id=section.center_chunk.document_id,
def context_from_inference_section(section: InferenceSection) -> OnyxContext:
return OnyxContext(
content=section.combined_content,
source_type=section.center_chunk.source_type,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
metadata=section.center_chunk.metadata,
updated_at=section.center_chunk.updated_at,
blurb=section.center_chunk.blurb,
link=link,
source_links=section.center_chunk.source_links,
match_highlights=section.center_chunk.match_highlights,
)

View File

@@ -78,19 +78,19 @@ def generate_dummy_chunk(
for i in range(number_of_document_sets):
document_set_names.append(f"Document Set {i}")
user_emails: list[str | None] = []
user_groups: list[str] = []
external_user_emails: list[str] = []
external_user_group_ids: list[str] = []
user_emails: set[str | None] = set()
user_groups: set[str] = set()
external_user_emails: set[str] = set()
external_user_group_ids: set[str] = set()
for i in range(number_of_acl_entries):
user_emails.append(f"user_{i}@example.com")
user_groups.append(f"group_{i}")
external_user_emails.append(f"external_user_{i}@example.com")
external_user_group_ids.append(f"external_group_{i}")
user_emails.add(f"user_{i}@example.com")
user_groups.add(f"group_{i}")
external_user_emails.add(f"external_user_{i}@example.com")
external_user_group_ids.add(f"external_group_{i}")
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=DocumentAccess.build(
access=DocumentAccess(
user_emails=user_emails,
user_groups=user_groups,
external_user_emails=external_user_emails,

View File

@@ -1,77 +0,0 @@
import os
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import BlobType
from onyx.connectors.blob.connector import BlobStorageConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import ACCEPTED_DOCUMENT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_IMAGE_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
from onyx.file_processing.extract_file_text import get_file_ext
@pytest.fixture
def blob_connector(request: pytest.FixtureRequest) -> BlobStorageConnector:
connector = BlobStorageConnector(
bucket_type=BlobType.S3, bucket_name="onyx-connector-tests"
)
connector.load_credentials(
{
"aws_access_key_id": os.environ["AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS"],
"aws_secret_access_key": os.environ[
"AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS"
],
}
)
return connector
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
def test_blob_s3_connector(
mock_get_api_key: MagicMock, blob_connector: BlobStorageConnector
) -> None:
"""
Plain and document file types should be fully indexed.
Multimedia and unknown file types will be indexed by title only with one empty section.
This is intentional in order to allow searching by just the title even if we can't
index the file content.
"""
all_docs: list[Document] = []
document_batches = blob_connector.load_from_state()
for doc_batch in document_batches:
for doc in doc_batch:
all_docs.append(doc)
#
assert len(all_docs) == 19
for doc in all_docs:
section = doc.sections[0]
assert isinstance(section, TextSection)
file_extension = get_file_ext(doc.semantic_identifier)
if file_extension in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS:
assert len(section.text) > 0
continue
if file_extension in ACCEPTED_DOCUMENT_FILE_EXTENSIONS:
assert len(section.text) > 0
continue
if file_extension in ACCEPTED_IMAGE_FILE_EXTENSIONS:
assert len(section.text) == 0
continue
# unknown extension
assert len(section.text) == 0
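
A hedged sketch of the policy that docstring describes, using the helpers imported in the removed test above; expect_indexed_text is a name introduced purely for illustration:

def expect_indexed_text(file_name: str) -> bool:
    # Content is indexed only for plain-text and document extensions; image
    # and unknown types keep one empty section so the title stays searchable.
    ext = get_file_ext(file_name)
    return (
        ext in ACCEPTED_PLAIN_TEXT_FILE_EXTENSIONS
        or ext in ACCEPTED_DOCUMENT_FILE_EXTENSIONS
    )

# The per-extension branches above could then collapse to:
# assert (len(section.text) > 0) == expect_indexed_text(doc.semantic_identifier)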

View File

@@ -58,16 +58,6 @@ SECTIONS_FOLDER_URL = (
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
)
EXTERNAL_SHARED_FOLDER_URL = (
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
)
EXTERNAL_SHARED_DOCS_IN_FOLDER = [
"https://docs.google.com/document/d/1Sywmv1-H6ENk2GcgieKou3kQHR_0te1mhIUcq8XlcdY"
]
EXTERNAL_SHARED_DOC_SINGLETON = (
"https://docs.google.com/document/d/11kmisDfdvNcw5LYZbkdPVjTOdj-Uc5ma6Jep68xzeeA"
)
SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA"
ADMIN_EMAIL = "admin@onyx-test.com"

View File

@@ -1,7 +1,6 @@
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from urllib.parse import urlparse
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
@@ -10,15 +9,6 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOC_SINGLETON,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOCS_IN_FOLDER,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
@@ -110,8 +100,7 @@ def test_include_shared_drives_only_with_size_threshold(
retrieved_docs = load_all_docs(connector)
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 52
assert len(retrieved_docs) == 50
@patch(
@@ -148,8 +137,7 @@ def test_include_shared_drives_only(
+ SECTIONS_FILE_IDS
)
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 53
assert len(retrieved_docs) == 51
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
@@ -306,64 +294,6 @@ def test_folders_only(
)
def test_shared_folder_owned_by_external_user(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_folder_owned_by_external_user")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=False,
include_files_shared_with_me=False,
shared_drive_urls=None,
shared_folder_urls=EXTERNAL_SHARED_FOLDER_URL,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
expected_docs = EXTERNAL_SHARED_DOCS_IN_FOLDER
assert len(retrieved_docs) == len(expected_docs) # 1 for now
assert expected_docs[0] in retrieved_docs[0].id
def test_shared_with_me(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_with_me")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=True,
include_files_shared_with_me=True,
shared_drive_urls=None,
shared_folder_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
print(retrieved_docs)
expected_file_ids = (
ADMIN_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
+ TEST_USER_1_FILE_IDS
+ TEST_USER_2_FILE_IDS
+ TEST_USER_3_FILE_IDS
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
retrieved_ids = {urlparse(doc.id).path.split("/")[-2] for doc in retrieved_docs}
for id in retrieved_ids:
print(id)
assert EXTERNAL_SHARED_DOC_SINGLETON.split("/")[-1] in retrieved_ids
assert EXTERNAL_SHARED_DOCS_IN_FOLDER[0].split("/")[-1] in retrieved_ids
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,

View File

@@ -6,7 +6,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 60
MAX_DELAY = 45
GENERAL_HEADERS = {"Content-Type": "application/json"}

View File

@@ -5,7 +5,6 @@ import requests
from requests.models import Response
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
@@ -98,24 +97,17 @@ class ChatSessionManager:
for data in response_data:
if "rephrased_query" in data:
analyzed.rephrased_query = data["rephrased_query"]
if "tool_name" in data:
elif "tool_name" in data:
analyzed.tool_name = data["tool_name"]
analyzed.tool_result = (
data.get("tool_result")
if analyzed.tool_name == "run_search"
else None
)
if "relevance_summaries" in data:
elif "relevance_summaries" in data:
analyzed.relevance_summaries = data["relevance_summaries"]
if "answer_piece" in data and data["answer_piece"]:
elif "answer_piece" in data and data["answer_piece"]:
analyzed.full_message += data["answer_piece"]
if "top_documents" in data:
assert (
analyzed.top_documents is None
), "top_documents should only be set once"
analyzed.top_documents = [
SavedSearchDoc(**doc) for doc in data["top_documents"]
]
return analyzed
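
The if-to-elif change above encodes an assumption that each streamed packet carries exactly one of these payload keys, so at most one branch fires per packet. A dispatch-table sketch of the same idea; the handler map is illustrative, not the real manager's API:

from typing import Any, Callable

# One handler per packet key; breaking after the first hit mirrors the
# elif chain, since keys are assumed mutually exclusive per packet.
PACKET_HANDLERS: dict[str, Callable[[Any, Any], None]] = {
    "rephrased_query": lambda analyzed, value: setattr(analyzed, "rephrased_query", value),
    "tool_name": lambda analyzed, value: setattr(analyzed, "tool_name", value),
    "relevance_summaries": lambda analyzed, value: setattr(analyzed, "relevance_summaries", value),
    "answer_piece": lambda analyzed, value: setattr(
        analyzed, "full_message", analyzed.full_message + value
    ),
}

def apply_packet(analyzed: Any, data: dict[str, Any]) -> None:
    for key, handler in PACKET_HANDLERS.items():
        if data.get(key) is not None:
            handler(analyzed, data[key])
            break  # one payload key per packet, like the elif chain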

View File

@@ -10,7 +10,6 @@ from pydantic import Field
from onyx.auth.schemas import UserRole
from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.models import SavedSearchDoc
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import IndexAttemptSnapshot
@@ -158,7 +157,7 @@ class StreamedResponse(BaseModel):
full_message: str = ""
rephrased_query: str | None = None
tool_name: str | None = None
top_documents: list[SavedSearchDoc] | None = None
top_documents: list[dict[str, Any]] | None = None
relevance_summaries: list[dict[str, Any]] | None = None
tool_result: Any | None = None
user: str | None = None

View File

@@ -5,7 +5,6 @@ This file contains tests for the following:
- Updating the document sets and user groups to remove the connector
- Ensuring that deleting a connector that is part of an overlapping document set and/or user group works as expected
"""
import os
from uuid import uuid4
from sqlalchemy.orm import Session
@@ -33,13 +32,6 @@ from tests.integration.common_utils.vespa import vespa_fixture
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
user_group_1: DATestUserGroup
user_group_2: DATestUserGroup
is_ee = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create api key
@@ -86,17 +78,16 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
print("Document sets created and synced")
if is_ee:
# create user groups
user_group_1 = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
user_group_2 = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(user_performing_action=admin_user)
# create user groups
user_group_1: DATestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
user_group_2: DATestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(user_performing_action=admin_user)
# inject a finished index attempt and index attempt error (exercises foreign key errors)
with Session(get_sqlalchemy_engine()) as db_session:
@@ -156,13 +147,12 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
)
# Update local records to match the database for later comparison
user_group_1.cc_pair_ids = []
user_group_2.cc_pair_ids = [cc_pair_2.id]
doc_set_1.cc_pair_ids = []
doc_set_2.cc_pair_ids = [cc_pair_2.id]
cc_pair_1.groups = []
if is_ee:
cc_pair_2.groups = [user_group_2.id]
else:
cc_pair_2.groups = []
cc_pair_2.groups = [user_group_2.id]
CCPairManager.wait_for_deletion_completion(
cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
@@ -178,15 +168,11 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
verify_deleted=True,
)
cc_pair_2_group_name_expected = []
if is_ee:
cc_pair_2_group_name_expected = [user_group_2.name]
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_set_names=[doc_set_2.name],
group_names=cc_pair_2_group_name_expected,
group_names=[user_group_2.name],
doc_creating_user=admin_user,
verify_deleted=False,
)
@@ -207,19 +193,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
user_performing_action=admin_user,
)
if is_ee:
user_group_1.cc_pair_ids = []
user_group_2.cc_pair_ids = [cc_pair_2.id]
# validate user groups
UserGroupManager.verify(
user_group=user_group_1,
user_performing_action=admin_user,
)
UserGroupManager.verify(
user_group=user_group_2,
user_performing_action=admin_user,
)
# validate user groups
UserGroupManager.verify(
user_group=user_group_1,
user_performing_action=admin_user,
)
UserGroupManager.verify(
user_group=user_group_2,
user_performing_action=admin_user,
)
def test_connector_deletion_for_overlapping_connectors(
@@ -228,13 +210,6 @@ def test_connector_deletion_for_overlapping_connectors(
"""Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping
document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors.
"""
user_group_1: DATestUserGroup
user_group_2: DATestUserGroup
is_ee = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create api key
@@ -306,48 +281,47 @@ def test_connector_deletion_for_overlapping_connectors(
doc_creating_user=admin_user,
)
if is_ee:
# create a user group and attach it to connector 1
user_group_1 = UserGroupManager.create(
name="Test User Group 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1],
user_performing_action=admin_user,
)
cc_pair_1.groups = [user_group_1.id]
# create a user group and attach it to connector 1
user_group_1: DATestUserGroup = UserGroupManager.create(
name="Test User Group 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1],
user_performing_action=admin_user,
)
cc_pair_1.groups = [user_group_1.id]
print("User group 1 created and synced")
print("User group 1 created and synced")
# create a user group and attach it to connector 2
user_group_2 = UserGroupManager.create(
name="Test User Group 2",
cc_pair_ids=[cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2],
user_performing_action=admin_user,
)
cc_pair_2.groups = [user_group_2.id]
# create a user group and attach it to connector 2
user_group_2: DATestUserGroup = UserGroupManager.create(
name="Test User Group 2",
cc_pair_ids=[cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2],
user_performing_action=admin_user,
)
cc_pair_2.groups = [user_group_2.id]
print("User group 2 created and synced")
print("User group 2 created and synced")
# verify vespa document is in the user group
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
# verify vespa document is in the user group
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
# delete connector 1
CCPairManager.pause_cc_pair(
@@ -380,15 +354,11 @@ def test_connector_deletion_for_overlapping_connectors(
# verify the document is not in any document sets
# verify the document is only in user group 2
group_names_expected = []
if is_ee:
group_names_expected = [user_group_2.name]
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_set_names=[],
group_names=group_names_expected,
group_names=[user_group_2.name],
doc_creating_user=admin_user,
verify_deleted=False,
)

View File

@@ -16,7 +16,10 @@ from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
def test_send_message_simple_with_history(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create connectors
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user,
@@ -50,13 +53,13 @@ def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -
response_json = response.json()
# Check that the top document is the correct document
assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
# assert that the metadata is correct
for doc in cc_pair_1.documents:
found_doc = next(
(x for x in response_json["top_documents"] if x["document_id"] == doc.id),
None,
(x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None
)
assert found_doc
assert found_doc["metadata"]["document_id"] == doc.id

View File

@@ -1,42 +0,0 @@
from collections.abc import Callable
import pytest
from onyx.configs.constants import DocumentSource
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import SimpleTestDocument
DocumentBuilderType = Callable[[list[str]], list[SimpleTestDocument]]
@pytest.fixture
def document_builder(admin_user: DATestUser) -> DocumentBuilderType:
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
# create connector
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
def _document_builder(contents: list[str]) -> list[SimpleTestDocument]:
# seed documents
docs: list[SimpleTestDocument] = [
DocumentManager.seed_doc_with_content(
cc_pair=cc_pair_1,
content=content,
api_key=api_key,
)
for content in contents
]
return docs
return _document_builder

View File

@@ -5,11 +5,12 @@ import pytest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.tests.streaming_endpoints.conftest import DocumentBuilderType
def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
def test_send_message_simple_with_history(reset: None) -> None:
admin_user: DATestUser = UserManager.create(name="admin_user")
LLMProviderManager.create(user_performing_action=admin_user)
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
@@ -23,44 +24,6 @@ def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -
assert len(response.full_message) > 0
def test_send_message__basic_searches(
reset: None, admin_user: DATestUser, document_builder: DocumentBuilderType
) -> None:
MESSAGE = "run a search for 'test'"
SHORT_DOC_CONTENT = "test"
LONG_DOC_CONTENT = "blah blah blah blah" * 100
LLMProviderManager.create(user_performing_action=admin_user)
short_doc = document_builder([SHORT_DOC_CONTENT])[0]
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message=MESSAGE,
user_performing_action=admin_user,
)
assert response.top_documents is not None
assert len(response.top_documents) == 1
assert response.top_documents[0].document_id == short_doc.id
# make sure this doc is really long so that it will be split into multiple chunks
long_doc = document_builder([LONG_DOC_CONTENT])[0]
# new chat session for simplicity
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message=MESSAGE,
user_performing_action=admin_user,
)
assert response.top_documents is not None
assert len(response.top_documents) == 2
# short doc should be more relevant and thus first
assert response.top_documents[0].document_id == short_doc.id
assert response.top_documents[1].document_id == long_doc.id
@pytest.mark.skip(
reason="enable for autorun when we have a testing environment with semantically useful data"
)

View File

@@ -9,6 +9,8 @@ from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import DocumentSource
@@ -17,6 +19,7 @@ from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
@@ -117,7 +120,24 @@ def mock_search_results(
@pytest.fixture
def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts:
return OnyxContexts(
contexts=[
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in mock_inference_sections
]
)
@pytest.fixture
def mock_search_tool(
mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc]
) -> MagicMock:
mock_tool = MagicMock(spec=SearchTool)
mock_tool.name = "search"
mock_tool.build_tool_message_content.return_value = "search_response"
@@ -126,6 +146,7 @@ def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
json.loads(doc.model_dump_json()) for doc in mock_search_results
]
mock_tool.run.return_value = [
ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts),
ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results),
]
mock_tool.tool_definition.return_value = {

View File

@@ -19,6 +19,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
@@ -32,6 +33,7 @@ from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -139,6 +141,7 @@ def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
def test_answer_with_search_call(
answer_instance: Answer,
mock_search_results: list[LlmDoc],
mock_contexts: OnyxContexts,
mock_search_tool: MagicMock,
force_use_tool: ForceUseTool,
expected_tool_args: dict,
@@ -194,21 +197,25 @@ def test_answer_with_search_call(
tool_name="search", tool_args=expected_tool_args
)
assert output[1] == ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=mock_contexts,
)
assert output[2] == ToolResponse(
id="final_context_documents",
response=mock_search_results,
)
assert output[2] == ToolCallFinalResult(
assert output[3] == ToolCallFinalResult(
tool_name="search",
tool_args=expected_tool_args,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
)
assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
assert output[4] == expected_citation
assert output[5] == OnyxAnswerPiece(
assert output[5] == expected_citation
assert output[6] == OnyxAnswerPiece(
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
)
assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
expected_answer = (
"Based on the search results, "
@@ -261,6 +268,7 @@ def test_answer_with_search_call(
def test_answer_with_search_no_tool_calling(
answer_instance: Answer,
mock_search_results: list[LlmDoc],
mock_contexts: OnyxContexts,
mock_search_tool: MagicMock,
) -> None:
answer_instance.graph_config.tooling.tools = [mock_search_tool]
@@ -280,26 +288,30 @@ def test_answer_with_search_no_tool_calling(
output = list(answer_instance.processed_streamed_output)
# Assertions
assert len(output) == 7
assert len(output) == 8
assert output[0] == ToolCallKickoff(
tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
)
assert output[1] == ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=mock_contexts,
)
assert output[2] == ToolResponse(
id=FINAL_CONTEXT_DOCUMENTS_ID,
response=mock_search_results,
)
assert output[2] == ToolCallFinalResult(
assert output[3] == ToolCallFinalResult(
tool_name="search",
tool_args=DEFAULT_SEARCH_ARGS,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
)
assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
assert output[4] == expected_citation
assert output[5] == OnyxAnswerPiece(
assert output[5] == expected_citation
assert output[6] == OnyxAnswerPiece(
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
)
assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
expected_answer = (
"Based on the search results, "

View File

@@ -79,7 +79,7 @@ def test_skip_gen_ai_answer_generation_flag(
for res in results:
print(res)
expected_count = 3 if skip_gen_ai_answer_generation else 4
expected_count = 4 if skip_gen_ai_answer_generation else 5
assert len(results) == expected_count
if not skip_gen_ai_answer_generation:
mock_llm.stream.assert_called_once()
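
Same root cause as the stream-shape changes above: the extra SEARCH_DOC_CONTENT_ID packet raises both expected counts by one. A small sketch of the count arithmetic; the breakdown in the comment is an assumption, not a documented contract:

def expected_packet_count(skip_gen_ai_answer_generation: bool) -> int:
    # Assumed breakdown: ToolCallKickoff + contexts ToolResponse
    # + final-docs ToolResponse + ToolCallFinalResult = 4 packets,
    # plus one streamed answer piece when generation is not skipped.
    tool_packets = 4
    return tool_packets if skip_gen_ai_answer_generation else tool_packets + 1

assert expected_packet_count(True) == 4
assert expected_packet_count(False) == 5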

View File

@@ -45,7 +45,7 @@ export function ActionsTable({ tools }: { tools: ToolSnapshot[] }) {
className="mr-1 my-auto cursor-pointer"
onClick={() =>
router.push(
`/admin/actions/edit/${tool.id}?u=${Date.now()}`
`/admin/tools/edit/${tool.id}?u=${Date.now()}`
)
}
/>

View File

@@ -281,7 +281,7 @@ export default function AddConnector({
return (
<Formik
initialValues={{
...createConnectorInitialValues(connector, currentCredential),
...createConnectorInitialValues(connector),
...Object.fromEntries(
connectorConfigs[connector].advanced_values.map((field) => [
field.name,

View File

@@ -1384,7 +1384,6 @@ export function ChatPage({
if (!packet) {
continue;
}
console.log("Packet:", JSON.stringify(packet));
if (!initialFetchDetails) {
if (!Object.hasOwn(packet, "user_message_id")) {
@@ -1730,7 +1729,6 @@ export function ChatPage({
}
}
} catch (e: any) {
console.log("Error:", e);
const errorMsg = e.message;
upsertToCompleteMessageMap({
messages: [
@@ -1758,13 +1756,11 @@ export function ChatPage({
completeMessageMapOverride: currentMessageMap(completeMessageDetail),
});
}
console.log("Finished streaming");
setAgenticGenerating(false);
resetRegenerationState(currentSessionId());
updateChatState("input");
if (isNewSession) {
console.log("Setting up new session");
if (finalMessage) {
setSelectedMessageForDocDisplay(finalMessage.message_id);
}

View File

@@ -34,12 +34,11 @@
/* -------------------------------------------------------
* 2. Keep special, custom, or near-duplicate background
* ------------------------------------------------------- */
--background: #fefcfa; /* slightly off-white */
--background-50: #fffdfb; /* a little lighter than background but not quite white */
--background: #fefcfa; /* slightly off-white, keep it */
--input-background: #fefcfa;
--input-border: #f1eee8;
--text-text: #f4f2ed;
--background-dark: #141414;
--background-dark: #e9e6e0;
--new-background: #ebe7de;
--new-background-light: #d9d1c0;
--background-chatbar: #f5f3ee;
@@ -235,7 +234,6 @@
--text-text: #1d1d1d;
--background-dark: #252525;
--background-50: #252525;
/* --new-background: #fff; */
--new-background: #2c2c2c;

View File

@@ -181,7 +181,7 @@ const SignedUpUserTable = ({
: "All Roles"}
</SelectValue>
</SelectTrigger>
<SelectContent className="bg-background-50">
<SelectContent className="bg-background">
{Object.entries(USER_ROLE_LABELS)
.filter(([role]) => role !== UserRole.EXT_PERM_USER)
.map(([role, label]) => (

View File

@@ -26,13 +26,7 @@ export const buildDocumentSummaryDisplay = (
matchHighlights: string[],
blurb: string
) => {
// if there are no match highlights, or if it's really short, just use the blurb
// this is to prevent the UI from showing something like `...` for the summary
const MIN_MATCH_HIGHLIGHT_LENGTH = 5;
if (
!matchHighlights ||
matchHighlights.length <= MIN_MATCH_HIGHLIGHT_LENGTH
) {
if (!matchHighlights || matchHighlights.length === 0) {
return blurb;
}

View File

@@ -1292,8 +1292,7 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
},
};
export function createConnectorInitialValues(
connector: ConfigurableSources,
currentCredential: Credential<any> | null = null
connector: ConfigurableSources
): Record<string, any> & AccessTypeGroupSelectorFormType {
const configuration = connectorConfigs[connector];
@@ -1308,16 +1307,7 @@ export function createConnectorInitialValues(
} else if (field.type === "list") {
acc[field.name] = field.default || [];
} else if (field.type === "checkbox") {
// Special case for include_files_shared_with_me when using service account
if (
field.name === "include_files_shared_with_me" &&
currentCredential &&
!currentCredential.credential_json?.google_tokens
) {
acc[field.name] = true;
} else {
acc[field.name] = field.default || false;
}
acc[field.name] = field.default || false;
} else if (field.default !== undefined) {
acc[field.name] = field.default;
}

View File

@@ -108,7 +108,6 @@ module.exports = {
"accent-background": "var(--accent-background)",
"accent-background-hovered": "var(--accent-background-hovered)",
"accent-background-selected": "var(--accent-background-selected)",
"background-50": "var(--background-50)",
"background-dark": "var(--off-white)",
"background-100": "var(--neutral-100-border-light)",
"background-125": "var(--neutral-125)",