Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 08:15:48 +00:00)

Comparing commits: fix_textse... → bugfix/exp... (1 commit)

Commit 54e61611c5
@@ -93,12 +93,12 @@ def _get_access_for_documents(
)

# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess.build(
user_emails=list(non_ee_access.user_emails),
user_groups=user_group_info.get(document_id, []),
access_map[document_id] = DocumentAccess(
user_emails=non_ee_access.user_emails,
user_groups=set(user_group_info.get(document_id, [])),
is_public=is_public_anywhere,
external_user_emails=list(ext_u_emails),
external_user_group_ids=list(ext_u_groups),
external_user_emails=ext_u_emails,
external_user_group_ids=ext_u_groups,
)
return access_map

@@ -2,6 +2,7 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream

@@ -31,6 +32,8 @@ def gather_stream_for_answer_api(
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
elif isinstance(packet, OnyxContexts):
response.contexts = packet

if answer:
response.answer = answer

@@ -14,6 +14,7 @@ from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain

@@ -55,6 +56,25 @@ logger = setup_logger()
router = APIRouter(prefix="/chat")

def _translate_doc_response_to_simple_doc(
doc_response: QADocsResponse,
) -> list[SimpleDoc]:
return [
SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier,
link=doc.link,
blurb=doc.blurb,
match_highlights=[
highlight for highlight in doc.match_highlights if highlight
],
source_type=doc.source_type,
metadata=doc.metadata,
)
for doc in doc_response.top_documents
]

def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,

@@ -91,6 +111,9 @@ def _convert_packet_stream_to_response(
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents

# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)

# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)

@@ -8,6 +8,7 @@ from pydantic import model_validator

from ee.onyx.server.manage.models import StandardAnswer
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier

@@ -163,6 +164,8 @@ class ChatBasicResponse(BaseModel):
cited_documents: dict[int, str] | None = None

# FOR BACKWARDS COMPATIBILITY
# TODO: deprecate both of these
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None

# agentic fields

@@ -217,3 +220,4 @@ class OneShotQAResponse(BaseModel):
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
contexts: OnyxContexts | None = None

@@ -18,7 +18,7 @@ def _get_access_for_document(
document_id=document_id,
)

doc_access = DocumentAccess.build(
return DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],

@@ -26,8 +26,6 @@ def _get_access_for_document(
is_public=info[2] if info else False,
)

return doc_access

def get_access_for_document(
document_id: str,

@@ -40,12 +38,12 @@ def get_access_for_document(

def get_null_document_access() -> DocumentAccess:
return DocumentAccess.build(
user_emails=[],
user_groups=[],
return DocumentAccess(
user_emails=set(),
user_groups=set(),
is_public=False,
external_user_emails=[],
external_user_group_ids=[],
external_user_emails=set(),
external_user_group_ids=set(),
)

@@ -58,18 +56,18 @@ def _get_access_for_documents(
document_ids=document_ids,
)
doc_access = {
document_id: DocumentAccess.build(
user_emails=[email for email in user_emails if email],
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
# MIT version will wipe all groups and external groups on update
user_groups=[],
user_groups=set(),
is_public=is_public,
external_user_emails=[],
external_user_group_ids=[],
external_user_emails=set(),
external_user_group_ids=set(),
)
for document_id, user_emails, is_public in document_access_info
}

# Sometimes the document has not been indexed by the indexing job yet, in those cases
# Sometimes the document has not be indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.

@@ -56,45 +56,33 @@ class DocExternalAccess:
)

@dataclass(frozen=True, init=False)
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin
user_emails: set[str | None]

# Names of user groups associated with this document
user_groups: set[str]

external_user_emails: set[str]
external_user_group_ids: set[str]
is_public: bool

def __init__(self) -> None:
raise TypeError(
"Use `DocumentAccess.build(...)` instead of creating an instance directly."
)

def to_acl(self) -> set[str]:
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly

acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:
acl_set.add(prefix_user_email(user_email))

for group_name in self.user_groups:
acl_set.add(prefix_user_group(group_name))

for external_user_email in self.external_user_emails:
acl_set.add(prefix_user_email(external_user_email))

for external_group_id in self.external_user_group_ids:
acl_set.add(prefix_external_group(external_group_id))

if self.is_public:
acl_set.add(PUBLIC_DOC_PAT)

return acl_set
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)

@classmethod
def build(

@@ -105,32 +93,29 @@ class DocumentAccess(ExternalAccess):
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
"""Don't prefix incoming data wth acl type, prefix on read from to_acl!"""

obj = object.__new__(cls)
object.__setattr__(
obj, "user_emails", {user_email for user_email in user_emails if user_email}
return cls(
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
)
object.__setattr__(obj, "user_groups", set(user_groups))
object.__setattr__(
obj,
"external_user_emails",
{external_email for external_email in external_user_emails},
)
object.__setattr__(
obj,
"external_user_group_ids",
{external_group_id for external_group_id in external_user_group_ids},
)
object.__setattr__(obj, "is_public", is_public)

return obj

default_public_access = DocumentAccess.build(
external_user_emails=[],
external_user_group_ids=[],
user_emails=[],
user_groups=[],
default_public_access = DocumentAccess(
external_user_emails=set(),
external_user_group_ids=set(),
user_emails=set(),
user_groups=set(),
is_public=True,
)

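For orientation, here is a minimal self-contained sketch of the two designs this hunk toggles between: direct construction of a frozen dataclass from pre-normalized sets, versus a build() factory that accepts lists and defers ACL prefixing to to_acl(). The prefix helpers and PUBLIC_DOC_PAT below are simplified stand-ins for the real utilities, not the actual implementations.

from dataclasses import dataclass


# Stand-ins for the real prefixing utilities (assumed shapes, not the actual code).
def prefix_user_email(email: str) -> str:
    return f"user_email:{email}"


def prefix_user_group(group: str) -> str:
    return f"group:{group}"


def prefix_external_group(group: str) -> str:
    return f"external_group:{group}"


PUBLIC_DOC_PAT = "PUBLIC"


@dataclass(frozen=True)
class AccessSketch:
    user_emails: set[str | None]
    user_groups: set[str]
    external_user_emails: set[str]
    external_user_group_ids: set[str]
    is_public: bool

    @classmethod
    def build(
        cls,
        user_emails: list[str | None],
        user_groups: list[str],
        external_user_emails: list[str],
        external_user_group_ids: list[str],
        is_public: bool,
    ) -> "AccessSketch":
        # Normalize the list inputs to sets; prefixing happens on read in to_acl().
        return cls(
            user_emails={e for e in user_emails if e},
            user_groups=set(user_groups),
            external_user_emails=set(external_user_emails),
            external_user_group_ids=set(external_user_group_ids),
            is_public=is_public,
        )

    def to_acl(self) -> set[str]:
        # ACL entries are prefixed by type so user groups and external groups
        # from different connectors cannot collide in the index.
        acl = {prefix_user_email(e) for e in self.user_emails if e}
        acl |= {prefix_user_group(g) for g in self.user_groups}
        acl |= {prefix_user_email(e) for e in self.external_user_emails}
        acl |= {prefix_external_group(g) for g in self.external_user_group_ids}
        if self.is_public:
            acl.add(PUBLIC_DOC_PAT)
        return acl


access = AccessSketch.build(["a@b.com", None], ["eng"], [], [], is_public=True)
assert PUBLIC_DOC_PAT in access.to_acl()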
@@ -7,6 +7,7 @@ from langgraph.types import StreamWriter

from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (

@@ -23,7 +24,7 @@ def process_llm_stream(
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")

@@ -156,6 +156,7 @@ def generate_initial_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

@@ -183,6 +183,7 @@ def generate_validate_refined_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

@@ -57,6 +57,7 @@ def format_results(
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

@@ -13,7 +13,9 @@ from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)

@@ -57,7 +59,9 @@ def basic_use_tool_response(
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(section_to_llm_doc(section))
initial_search_results.append(
context_from_inference_section(section)
)

new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

@@ -389,8 +389,6 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
credential_id_to_delete: int | None = None
connector_id_to_delete: int | None = None
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"

@@ -445,35 +443,26 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)

# Store IDs before potentially expiring cc_pair
connector_id_to_delete = cc_pair.connector_id
credential_id_to_delete = cc_pair.credential_id

# Explicitly delete document by connector credential pair records before deleting the connector
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
delete_all_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)

# Flush to ensure document deletion happens before connector deletion
db_session.flush()

# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)

# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=connector_id_to_delete,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(

@@ -506,15 +495,15 @@ def monitor_connector_deletion_taskset(

task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e

task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={connector_id_to_delete} "
f"credential={credential_id_to_delete} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={fence_data.num_tasks}"
)

@@ -564,7 +553,7 @@ def validate_connector_deletion_fences(
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_upsert_tasks: set[str],
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.

@@ -651,7 +640,7 @@ def validate_connector_deletion_fence(

member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_upsert_tasks:
if member_str in queued_tasks:
continue

tasks_not_in_celery += 1

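The ordering in the hunk above matters. Here is a hedged sketch of the intended sequence, with the bulk-delete helpers stubbed out (the real ones live in the onyx.db modules and have different names):

from sqlalchemy.orm import Session


def _delete_documents_by_cc_pair__no_commit(
    db_session: Session, connector_id: int, credential_id: int
) -> None:
    ...  # stand-in for the real bulk delete


def _delete_cc_pair__no_commit(
    db_session: Session, connector_id: int, credential_id: int
) -> None:
    ...  # stand-in


def delete_cc_pair_sketch(db_session: Session, cc_pair) -> None:
    # Capture the IDs first: once the object is expired, attribute access would
    # trigger a refresh against rows that may already be deleted.
    connector_id = cc_pair.connector_id
    credential_id = cc_pair.credential_id

    # Delete the document-by-cc-pair rows explicitly; connector_id is part of
    # that table's primary key, so cascading deletes will not cover them.
    _delete_documents_by_cc_pair__no_commit(db_session, connector_id, credential_id)

    # Flush so the document deletions are ordered before the connector deletion.
    db_session.flush()

    # Expire the ORM object so SQLAlchemy does not try to reconcile its state
    # with the deleted DocumentByConnectorCredentialPair rows at commit time.
    db_session.expire(cc_pair)

    _delete_cc_pair__no_commit(db_session, connector_id, credential_id)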
@@ -194,6 +194,17 @@ class StreamingError(BaseModel):
stack_trace: str | None = None

class OnyxContext(BaseModel):
content: str
document_id: str
semantic_identifier: str
blurb: str

class OnyxContexts(BaseModel):
contexts: list[OnyxContext]

class OnyxAnswer(BaseModel):
answer: str | None

@@ -259,6 +270,7 @@ class PersonaOverrideConfig(BaseModel):
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
| OnyxContexts
| FileChatDisplay
| CustomToolResponse
| StreamingError

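The OnyxContext/OnyxContexts classes added above are plain Pydantic models. A small usage sketch (the model_dump_json call assumes Pydantic v2):

from pydantic import BaseModel


class OnyxContext(BaseModel):
    content: str
    document_id: str
    semantic_identifier: str
    blurb: str


class OnyxContexts(BaseModel):
    contexts: list[OnyxContext]


packet = OnyxContexts(
    contexts=[
        OnyxContext(
            content="full combined text of the matched section",
            document_id="doc-1",
            semantic_identifier="Quarterly report",
            blurb="A short preview of the document...",
        )
    ]
)
print(packet.model_dump_json())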
@@ -29,6 +29,7 @@ from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement

@@ -130,6 +131,7 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)

@@ -298,6 +300,7 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| OnyxContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail

@@ -916,6 +919,8 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(OnyxContexts, packet.response)

elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:

@@ -301,10 +301,6 @@ def prune_sections(

def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"

# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)

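A minimal sketch of what the asserted invariant in _merge_doc_chunks buys: all chunks must share one document_id, after which sorting by chunk_id is enough to merge them. The chunk type below is a stand-in for InferenceChunk:

from dataclasses import dataclass


@dataclass
class ChunkSketch:  # stand-in for InferenceChunk
    document_id: str
    chunk_id: int
    content: str


def merge_doc_chunks_sketch(chunks: list[ChunkSketch]) -> str:
    # Mirrors the assert in the hunk: exactly one distinct document allowed.
    assert len({c.document_id for c in chunks}) == 1, (
        "One distinct document must be passed into merge_doc_chunks"
    )
    # Assuming no duplicates by this point, sort by chunk_id and join.
    ordered = sorted(chunks, key=lambda c: c.chunk_id)
    return "\n".join(c.content for c in ordered)


print(merge_doc_chunks_sketch([
    ChunkSketch("d1", 2, "second part"),
    ChunkSketch("d1", 1, "first part"),
]))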
@@ -3,6 +3,7 @@ from collections.abc import Sequence
from pydantic import BaseModel

from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceChunk

@@ -11,7 +12,7 @@ class DocumentIdOrderMapping(BaseModel):

def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
chunks: Sequence[InferenceChunk | LlmDoc | OnyxContext], one_indexed: bool = True
) -> DocumentIdOrderMapping:
order_mapping = {}
current = 1 if one_indexed else 0

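The widened signature above only changes what map_document_id_order accepts; the underlying logic is a first-seen ordinal per document_id, roughly:

from dataclasses import dataclass


@dataclass
class DocLike:  # any object with a document_id, e.g. InferenceChunk, LlmDoc, OnyxContext
    document_id: str


def map_document_id_order_sketch(
    chunks: list[DocLike], one_indexed: bool = True
) -> dict[str, int]:
    order_mapping: dict[str, int] = {}
    current = 1 if one_indexed else 0
    for chunk in chunks:
        # First appearance wins; later chunks of the same document keep the index.
        if chunk.document_id not in order_mapping:
            order_mapping[chunk.document_id] = current
            current += 1
    return order_mapping


assert map_document_id_order_sketch(
    [DocLike("a"), DocLike("b"), DocLike("a")]
) == {"a": 1, "b": 2}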
@@ -28,9 +28,7 @@ from onyx.connectors.google_drive.doc_conversion import (
)
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import DriveRetrievalStage

@@ -88,18 +86,13 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:

def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
size_threshold: int,
retriever_email: str,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
# We used to always get the user email from the file owners when available,
# but this was causing issues with shared folders where the owner was not included in the service account
# now we use the email of the account that successfully listed the file. Leaving this in case we end up
# wanting to retry with file owners and/or admin email at some point.
# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email

user_email = retriever_email
# Only construct these services when needed
user_drive_service = lazy_eval(
lambda: get_drive_service(creds, user_email=user_email)

@@ -457,11 +450,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.info(f"Getting all files in my drive as '{user_email}'")

yield from add_retrieval_info(
get_all_files_in_my_drive_and_shared(
get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
include_shared_with_me=self.include_files_shared_with_me,
start=curr_stage.completed_until if resuming else start,
end=end,
),

@@ -924,28 +916,20 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[RetrievedDriveFile] = []
files_batch: list[GoogleDriveFileType] = []

def _yield_batch(
files_batch: list[RetrievedDriveFile],
files_batch: list[GoogleDriveFileType],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [
(
convert_func,
(
file.user_email,
file.drive_file,
),
)
for file in files_batch
]
func_with_args = [(convert_func, (file,)) for file in files_batch]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),

@@ -983,7 +967,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)

continue
files_batch.append(retrieved_file)
files_batch.append(retrieved_file.drive_file)

if len(files_batch) < self.batch_size:
continue

@@ -87,17 +87,35 @@ def _download_and_extract_sections_basic(
mime_type = file["mimeType"]
link = file.get("webViewLink", "")

# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []

# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()

response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []

text = response.decode("utf-8")
return [TextSection(link=link, text=text)]

# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False

@@ -106,100 +124,88 @@ def _download_and_extract_sections_basic(

response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
logger.warning(f"Failed to download {file_name}")
return []

text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]

# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]

response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]

# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]

elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]

elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]

elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]

elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections

elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]

# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections

else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]

# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections

else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []

except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []

def convert_drive_item_to_document(

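The restructuring above mostly re-indents the MIME-type dispatch under one try block so a single bad file cannot abort the batch. A compressed sketch of that control flow, with the format converters stubbed out as stand-ins:

import io


def _generic_extract(data: io.BytesIO, file_name: str) -> str:
    ...  # stand-in for extract_file_text
    return ""


def extract_sections_sketch(
    mime_type: str, data: bytes, file_name: str
) -> list[str]:
    try:
        if mime_type == "text/plain":
            return [data.decode("utf-8")]
        # ...docx/xlsx/pptx/pdf/image branches would go here, as in the hunk...
        # Unsupported types fall through to a best-effort text extractor.
        try:
            return [_generic_extract(io.BytesIO(data), file_name)]
        except Exception:
            return []
    except Exception:
        # One failing file yields no sections instead of raising.
        return []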
@@ -214,11 +214,10 @@ def get_files_in_shared_drive(
yield file

def get_all_files_in_my_drive_and_shared(
def get_all_files_in_my_drive(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool,
include_shared_with_me: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:

@@ -230,8 +229,7 @@ def get_all_files_in_my_drive_and_shared(
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,

@@ -248,8 +246,7 @@ def get_all_files_in_my_drive_and_shared(
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,

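A small sketch of the query construction being toggled above: when include_shared_with_me is set, the ownership restriction is dropped so Drive also returns files merely shared with the user. DRIVE_FOLDER_TYPE below is the standard Drive folder MIME type.

DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"


def build_my_drive_file_query(include_shared_with_me: bool) -> str:
    file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
    file_query += " and trashed = false"
    if not include_shared_with_me:
        # Restrict to files the user owns; dropping this clause pulls in
        # files that are only shared with them.
        file_query += " and 'me' in owners"
    return file_query


print(build_my_drive_file_query(include_shared_with_me=True))
print(build_my_drive_file_query(include_shared_with_me=False))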
@@ -339,12 +339,6 @@ class SearchPipeline:
self._retrieved_sections = self._get_sections()
return self._retrieved_sections

@property
def merged_retrieved_sections(self) -> list[InferenceSection]:
"""Should be used to display in the UI in order to prevent displaying
multiple sections for the same document as separate "documents"."""
return _merge_sections(sections=self.retrieved_sections)

@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily

@@ -421,10 +415,6 @@ class SearchPipeline:
raise ValueError(
"Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
)
# NOTE: final_context_sections must be accessed before accessing self._postprocessing_generator
# since the property sets the generator. DO NOT REMOVE.
_ = self.final_context_sections

self._section_relevance = next(
cast(
Iterator[list[SectionRelevancePiece]],

@@ -821,26 +821,30 @@ class VespaIndex(DocumentIndex):
num_to_retrieve: int = NUM_RETURNED_HITS,
offset: int = 0,
) -> list[InferenceChunkUncleaned]:
vespa_where_clauses = build_vespa_filters(filters, include_hidden=True)
yql = (
YQL_BASE.format(index_name=self.index_name)
+ vespa_where_clauses
+ '({grammar: "weakAnd"}userInput(@query) '
# `({defaultIndex: "content_summary"}userInput(@query))` section is
# needed for highlighting while the N-gram highlighting is broken /
# not working as desired
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
vespa_where_clauses = build_vespa_filters(
filters, include_hidden=True, remove_trailing_and=True
)
yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses

params: dict[str, str | int] = {
"yql": yql,
"query": query,
"hits": num_to_retrieve,
"offset": 0,
"ranking.profile": "admin_search",
"timeout": VESPA_TIMEOUT,
}

if len(query.strip()) > 0:
yql += (
' and ({grammar: "weakAnd"}userInput(@query) '
# `({defaultIndex: "content_summary"}userInput(@query))` section is
# needed for highlighting while the N-gram highlighting is broken /
# not working as desired
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
)
params["yql"] = yql
params["query"] = query

return query_vespa(params)

# Retrieves chunk information for a document:

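One version in the hunk above always appends the weakAnd userInput clause; the other attaches it (and the query parameter) only when the query string is non-empty. A sketch of the conditional construction, with YQL_BASE as a simplified stand-in for the real template:

CONTENT_SUMMARY = "content_summary"
YQL_BASE = "select * from sources {index_name} where "


def build_admin_retrieval_params(
    index_name: str, where_clauses: str, query: str, hits: int
) -> dict[str, str | int]:
    yql = YQL_BASE.format(index_name=index_name) + where_clauses
    params: dict[str, str | int] = {
        "yql": yql,
        "hits": hits,
        "offset": 0,
        "ranking.profile": "admin_search",
    }
    if len(query.strip()) > 0:
        # Only add the match clause and query parameter for real queries; the
        # content_summary userInput keeps highlighting working.
        yql += (
            ' and ({grammar: "weakAnd"}userInput(@query) '
            f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
        )
        params["yql"] = yql
        params["query"] = query
    return params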
@@ -224,6 +224,27 @@ class Chunker:
)
chunks_list.append(new_chunk)

def _chunk_document(
self,
document: IndexingDocument,
title_prefix: str,
metadata_suffix_semantic: str,
metadata_suffix_keyword: str,
content_token_limit: int,
) -> list[DocAwareChunk]:
"""
Legacy method for backward compatibility.
Calls _chunk_document_with_sections with document.sections.
"""
return self._chunk_document_with_sections(
document,
document.processed_sections,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
content_token_limit,
)

def _chunk_document_with_sections(
self,
document: IndexingDocument,

@@ -243,7 +264,7 @@ class Chunker:

for section_idx, section in enumerate(sections):
# Get section text and other attributes
section_text = clean_text(str(section.text or ""))
section_text = clean_text(section.text or "")
section_link_text = section.link or ""
image_url = section.image_file_name

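The _chunk_document shown above is a thin compatibility shim; the pattern is plain delegation to the sections-aware method:

class ChunkerSketch:
    def _chunk_document_with_sections(self, document, sections, *limits):
        ...  # the real chunking logic

    def _chunk_document(self, document, *limits):
        # Legacy entry point: forward to the sections-aware method using the
        # document's already-processed sections.
        return self._chunk_document_with_sections(
            document, document.processed_sections, *limits
        )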
@@ -439,7 +439,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
**document.dict(),
processed_sections=[
Section(
text=section.text if isinstance(section, TextSection) else "",
text=section.text if isinstance(section, TextSection) else None,
link=section.link,
image_file_name=section.image_file_name
if isinstance(section, ImageSection)

@@ -459,11 +459,11 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
for section in document.sections:
# For ImageSection, process and create base Section with both text and image_file_name
if isinstance(section, ImageSection):
# Default section with image path preserved - ensure text is always a string
# Default section with image path preserved
processed_section = Section(
link=section.link,
image_file_name=section.image_file_name,
text="",  # Initialize with empty string
text=None,  # Will be populated if summarization succeeds
)

# Try to get image summary

@@ -506,21 +506,13 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
# For TextSection, create a base Section with text and link
elif isinstance(section, TextSection):
processed_section = Section(
text=section.text or "",  # Ensure text is always a string, not None
link=section.link,
image_file_name=None,
text=section.text, link=section.link, image_file_name=None
)
processed_sections.append(processed_section)

# If it's already a base Section (unlikely), just append it with text validation
# If it's already a base Section (unlikely), just append it
else:
# Ensure text is always a string
processed_section = Section(
text=section.text if section.text is not None else "",
link=section.link,
image_file_name=section.image_file_name,
)
processed_sections.append(processed_section)
processed_sections.append(section)

# Create IndexingDocument with original sections and processed_sections
indexed_document = IndexingDocument(

@@ -12,6 +12,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder

@@ -41,6 +42,9 @@ from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from onyx.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,

@@ -54,6 +58,7 @@ from onyx.utils.special_types import JSON_ro
logger = setup_logger()

SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"

@@ -352,13 +357,13 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
)
yield from yield_search_responses(
query=query,
# give back the merged sections to prevent duplicate docs from appearing in the UI
get_retrieved_sections=lambda: search_pipeline.merged_retrieved_sections,
get_final_context_sections=lambda: search_pipeline.final_context_sections,
search_query_info=search_query_info,
get_section_relevance=lambda: search_pipeline.section_relevance,
search_tool=self,
query,
lambda: search_pipeline.retrieved_sections,
lambda: search_pipeline.reranked_sections,
lambda: search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
)

def final_result(self, *args: ToolResponse) -> JSON_ro:

@@ -400,6 +405,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
def yield_search_responses(
query: str,
get_retrieved_sections: Callable[[], list[InferenceSection]],
get_reranked_sections: Callable[[], list[InferenceSection]],
get_final_context_sections: Callable[[], list[InferenceSection]],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],

@@ -417,6 +423,16 @@ def yield_search_responses(
),
)

yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
context_from_inference_section(section)
for section in get_reranked_sections()
]
),
)

section_relevance = get_section_relevance()
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,

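yield_search_responses takes zero-argument callables rather than lists so that expensive pipeline properties (retrieval, reranking, relevance) are computed only when the corresponding response packet is actually produced. A toy sketch of the pattern:

from collections.abc import Callable, Iterator


def yield_search_responses_sketch(
    get_retrieved_sections: Callable[[], list[str]],
    get_final_context_sections: Callable[[], list[str]],
) -> Iterator[tuple[str, list[str]]]:
    # Each callable runs only here, when its packet is yielded.
    yield ("search_response_summary", get_retrieved_sections())
    yield ("final_context_documents", get_final_context_sections())


sections = ["section-a", "section-b"]
for packet_id, payload in yield_search_responses_sketch(
    get_retrieved_sections=lambda: sections,
    get_final_context_sections=lambda: sections,
):
    print(packet_id, payload)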
@@ -1,4 +1,5 @@
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceSection
from onyx.prompts.prompt_utils import clean_up_source

@@ -31,23 +32,10 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
return doc_dict

def section_to_llm_doc(section: InferenceSection) -> LlmDoc:
possible_link_chunks = [section.center_chunk] + section.chunks
link: str | None = None
for chunk in possible_link_chunks:
if chunk.source_links:
link = list(chunk.source_links.values())[0]
break

return LlmDoc(
document_id=section.center_chunk.document_id,
def context_from_inference_section(section: InferenceSection) -> OnyxContext:
return OnyxContext(
content=section.combined_content,
source_type=section.center_chunk.source_type,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
metadata=section.center_chunk.metadata,
updated_at=section.center_chunk.updated_at,
blurb=section.center_chunk.blurb,
link=link,
source_links=section.center_chunk.source_links,
match_highlights=section.center_chunk.match_highlights,
)

@@ -78,19 +78,19 @@ def generate_dummy_chunk(
for i in range(number_of_document_sets):
document_set_names.append(f"Document Set {i}")

user_emails: list[str | None] = []
user_groups: list[str] = []
external_user_emails: list[str] = []
external_user_group_ids: list[str] = []
user_emails: set[str | None] = set()
user_groups: set[str] = set()
external_user_emails: set[str] = set()
external_user_group_ids: set[str] = set()
for i in range(number_of_acl_entries):
user_emails.append(f"user_{i}@example.com")
user_groups.append(f"group_{i}")
external_user_emails.append(f"external_user_{i}@example.com")
external_user_group_ids.append(f"external_group_{i}")
user_emails.add(f"user_{i}@example.com")
user_groups.add(f"group_{i}")
external_user_emails.add(f"external_user_{i}@example.com")
external_user_group_ids.add(f"external_group_{i}")

return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=DocumentAccess.build(
access=DocumentAccess(
user_emails=user_emails,
user_groups=user_groups,
external_user_emails=external_user_emails,

@@ -58,16 +58,6 @@ SECTIONS_FOLDER_URL = (
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
)

EXTERNAL_SHARED_FOLDER_URL = (
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
)
EXTERNAL_SHARED_DOCS_IN_FOLDER = [
"https://docs.google.com/document/d/1Sywmv1-H6ENk2GcgieKou3kQHR_0te1mhIUcq8XlcdY"
]
EXTERNAL_SHARED_DOC_SINGLETON = (
"https://docs.google.com/document/d/11kmisDfdvNcw5LYZbkdPVjTOdj-Uc5ma6Jep68xzeeA"
)

SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA"

ADMIN_EMAIL = "admin@onyx-test.com"

@@ -1,7 +1,6 @@
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from urllib.parse import urlparse

from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL

@@ -10,15 +9,6 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOC_SINGLETON,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOCS_IN_FOLDER,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS

@@ -110,8 +100,7 @@ def test_include_shared_drives_only_with_size_threshold(

retrieved_docs = load_all_docs(connector)

# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 52
assert len(retrieved_docs) == 50

@patch(

@@ -148,8 +137,7 @@ def test_include_shared_drives_only(
+ SECTIONS_FILE_IDS
)

# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 53
assert len(retrieved_docs) == 51

assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,

@@ -306,64 +294,6 @@ def test_folders_only(
)

def test_shared_folder_owned_by_external_user(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_folder_owned_by_external_user")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=False,
include_files_shared_with_me=False,
shared_drive_urls=None,
shared_folder_urls=EXTERNAL_SHARED_FOLDER_URL,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)

expected_docs = EXTERNAL_SHARED_DOCS_IN_FOLDER

assert len(retrieved_docs) == len(expected_docs)  # 1 for now
assert expected_docs[0] in retrieved_docs[0].id

def test_shared_with_me(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_with_me")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=True,
include_files_shared_with_me=True,
shared_drive_urls=None,
shared_folder_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)

print(retrieved_docs)

expected_file_ids = (
ADMIN_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
+ TEST_USER_1_FILE_IDS
+ TEST_USER_2_FILE_IDS
+ TEST_USER_3_FILE_IDS
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)

retrieved_ids = {urlparse(doc.id).path.split("/")[-2] for doc in retrieved_docs}
for id in retrieved_ids:
print(id)

assert EXTERNAL_SHARED_DOC_SINGLETON.split("/")[-1] in retrieved_ids
assert EXTERNAL_SHARED_DOCS_IN_FOLDER[0].split("/")[-1] in retrieved_ids

@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,

@@ -6,7 +6,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 60
MAX_DELAY = 45

GENERAL_HEADERS = {"Content-Type": "application/json"}

@@ -5,7 +5,6 @@ import requests
from requests.models import Response

from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride

@@ -98,24 +97,17 @@ class ChatSessionManager:
for data in response_data:
if "rephrased_query" in data:
analyzed.rephrased_query = data["rephrased_query"]
if "tool_name" in data:
elif "tool_name" in data:
analyzed.tool_name = data["tool_name"]
analyzed.tool_result = (
data.get("tool_result")
if analyzed.tool_name == "run_search"
else None
)
if "relevance_summaries" in data:
elif "relevance_summaries" in data:
analyzed.relevance_summaries = data["relevance_summaries"]
if "answer_piece" in data and data["answer_piece"]:
elif "answer_piece" in data and data["answer_piece"]:
analyzed.full_message += data["answer_piece"]
if "top_documents" in data:
assert (
analyzed.top_documents is None
), "top_documents should only be set once"
analyzed.top_documents = [
SavedSearchDoc(**doc) for doc in data["top_documents"]
]

return analyzed

@@ -10,7 +10,6 @@ from pydantic import Field
from onyx.auth.schemas import UserRole
from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.models import SavedSearchDoc
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import IndexAttemptSnapshot

@@ -158,7 +157,7 @@ class StreamedResponse(BaseModel):
full_message: str = ""
rephrased_query: str | None = None
tool_name: str | None = None
top_documents: list[SavedSearchDoc] | None = None
top_documents: list[dict[str, Any]] | None = None
relevance_summaries: list[dict[str, Any]] | None = None
tool_result: Any | None = None
user: str | None = None

@@ -5,7 +5,6 @@ This file contains tests for the following:
- updates the document sets and user groups to remove the connector
- Ensure that deleting a connector that is part of an overlapping document set and/or user group works as expected
"""
import os
from uuid import uuid4

from sqlalchemy.orm import Session
@@ -33,13 +32,6 @@ from tests.integration.common_utils.vespa import vespa_fixture


def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
    user_group_1: DATestUserGroup
    user_group_2: DATestUserGroup

    is_ee = (
        os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
    )

    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
    # create api key
@@ -86,17 +78,16 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:

    print("Document sets created and synced")

    if is_ee:
        # create user groups
        user_group_1 = UserGroupManager.create(
            cc_pair_ids=[cc_pair_1.id],
            user_performing_action=admin_user,
        )
        user_group_2 = UserGroupManager.create(
            cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
            user_performing_action=admin_user,
        )
        UserGroupManager.wait_for_sync(user_performing_action=admin_user)
    # create user groups
    user_group_1: DATestUserGroup = UserGroupManager.create(
        cc_pair_ids=[cc_pair_1.id],
        user_performing_action=admin_user,
    )
    user_group_2: DATestUserGroup = UserGroupManager.create(
        cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(user_performing_action=admin_user)

    # inject a finished index attempt and index attempt error (exercises foreign key errors)
    with Session(get_sqlalchemy_engine()) as db_session:
@@ -156,13 +147,12 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
    )

    # Update local records to match the database for later comparison
    user_group_1.cc_pair_ids = []
    user_group_2.cc_pair_ids = [cc_pair_2.id]
    doc_set_1.cc_pair_ids = []
    doc_set_2.cc_pair_ids = [cc_pair_2.id]
    cc_pair_1.groups = []
    if is_ee:
        cc_pair_2.groups = [user_group_2.id]
    else:
        cc_pair_2.groups = []
    cc_pair_2.groups = [user_group_2.id]

    CCPairManager.wait_for_deletion_completion(
        cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
@@ -178,15 +168,11 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
        verify_deleted=True,
    )

    cc_pair_2_group_name_expected = []
    if is_ee:
        cc_pair_2_group_name_expected = [user_group_2.name]

    DocumentManager.verify(
        vespa_client=vespa_client,
        cc_pair=cc_pair_2,
        doc_set_names=[doc_set_2.name],
        group_names=cc_pair_2_group_name_expected,
        group_names=[user_group_2.name],
        doc_creating_user=admin_user,
        verify_deleted=False,
    )
@@ -207,19 +193,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
        user_performing_action=admin_user,
    )

    if is_ee:
        user_group_1.cc_pair_ids = []
        user_group_2.cc_pair_ids = [cc_pair_2.id]

        # validate user groups
        UserGroupManager.verify(
            user_group=user_group_1,
            user_performing_action=admin_user,
        )
        UserGroupManager.verify(
            user_group=user_group_2,
            user_performing_action=admin_user,
        )
    # validate user groups
    UserGroupManager.verify(
        user_group=user_group_1,
        user_performing_action=admin_user,
    )
    UserGroupManager.verify(
        user_group=user_group_2,
        user_performing_action=admin_user,
    )


def test_connector_deletion_for_overlapping_connectors(
@@ -228,13 +210,6 @@ def test_connector_deletion_for_overlapping_connectors(
    """Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping
    document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors.
    """
    user_group_1: DATestUserGroup
    user_group_2: DATestUserGroup

    is_ee = (
        os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
    )

    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")
    # create api key
@@ -306,48 +281,47 @@ def test_connector_deletion_for_overlapping_connectors(
        doc_creating_user=admin_user,
    )

    if is_ee:
        # create a user group and attach it to connector 1
        user_group_1 = UserGroupManager.create(
            name="Test User Group 1",
            cc_pair_ids=[cc_pair_1.id],
            user_performing_action=admin_user,
        )
        UserGroupManager.wait_for_sync(
            user_groups_to_check=[user_group_1],
            user_performing_action=admin_user,
        )
        cc_pair_1.groups = [user_group_1.id]
    # create a user group and attach it to connector 1
    user_group_1: DATestUserGroup = UserGroupManager.create(
        name="Test User Group 1",
        cc_pair_ids=[cc_pair_1.id],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_1],
        user_performing_action=admin_user,
    )
    cc_pair_1.groups = [user_group_1.id]

        print("User group 1 created and synced")
    print("User group 1 created and synced")

        # create a user group and attach it to connector 2
        user_group_2 = UserGroupManager.create(
            name="Test User Group 2",
            cc_pair_ids=[cc_pair_2.id],
            user_performing_action=admin_user,
        )
        UserGroupManager.wait_for_sync(
            user_groups_to_check=[user_group_2],
            user_performing_action=admin_user,
        )
        cc_pair_2.groups = [user_group_2.id]
    # create a user group and attach it to connector 2
    user_group_2: DATestUserGroup = UserGroupManager.create(
        name="Test User Group 2",
        cc_pair_ids=[cc_pair_2.id],
        user_performing_action=admin_user,
    )
    UserGroupManager.wait_for_sync(
        user_groups_to_check=[user_group_2],
        user_performing_action=admin_user,
    )
    cc_pair_2.groups = [user_group_2.id]

        print("User group 2 created and synced")
    print("User group 2 created and synced")

        # verify vespa document is in the user group
        DocumentManager.verify(
            vespa_client=vespa_client,
            cc_pair=cc_pair_1,
            group_names=[user_group_1.name, user_group_2.name],
            doc_creating_user=admin_user,
        )
        DocumentManager.verify(
            vespa_client=vespa_client,
            cc_pair=cc_pair_2,
            group_names=[user_group_1.name, user_group_2.name],
            doc_creating_user=admin_user,
        )
    # verify vespa document is in the user group
    DocumentManager.verify(
        vespa_client=vespa_client,
        cc_pair=cc_pair_1,
        group_names=[user_group_1.name, user_group_2.name],
        doc_creating_user=admin_user,
    )
    DocumentManager.verify(
        vespa_client=vespa_client,
        cc_pair=cc_pair_2,
        group_names=[user_group_1.name, user_group_2.name],
        doc_creating_user=admin_user,
    )

    # delete connector 1
    CCPairManager.pause_cc_pair(
@@ -380,15 +354,11 @@ def test_connector_deletion_for_overlapping_connectors(

    # verify the document is not in any document sets
    # verify the document is only in user group 2
    group_names_expected = []
    if is_ee:
        group_names_expected = [user_group_2.name]

    DocumentManager.verify(
        vespa_client=vespa_client,
        cc_pair=cc_pair_2,
        doc_set_names=[],
        group_names=group_names_expected,
        group_names=[user_group_2.name],
        doc_creating_user=admin_user,
        verify_deleted=False,
    )

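The removed test variant above gates enterprise-only assertions on an environment flag rather than maintaining separate EE and community test files. A minimal sketch of that gating pattern (the group name below is a placeholder, not the real fixture value):

    import os

    # EE-only expectations collapse to empty defaults on community builds
    is_ee = (
        os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
    )

    expected_group_names: list[str] = []
    if is_ee:
        expected_group_names = ["Test User Group 2"]  # placeholder name
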
@@ -16,7 +16,10 @@ from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser


def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
def test_send_message_simple_with_history(reset: None) -> None:
    # Creating an admin user (first user created is automatically an admin)
    admin_user: DATestUser = UserManager.create(name="admin_user")

    # create connectors
    cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
        user_performing_action=admin_user,
@@ -50,13 +53,13 @@ def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -
    response_json = response.json()

    # Check that the top document is the correct document
    assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
    assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id

    # assert that the metadata is correct
    for doc in cc_pair_1.documents:
        found_doc = next(
            (x for x in response_json["top_documents"] if x["document_id"] == doc.id),
            None,
            (x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None
        )
        assert found_doc
        assert found_doc["metadata"]["document_id"] == doc.id

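The assertion swap above moves the test between the deprecated `simple_search_docs` payload, whose entries are keyed by `id`, and `top_documents`, whose entries are keyed by `document_id`. A hedged sketch of the two shapes side by side, with made-up values:

    # made-up response payload illustrating both shapes
    response_json = {
        "top_documents": [{"document_id": "doc-123"}],
        "simple_search_docs": [{"id": "doc-123"}],  # deprecated shape
    }

    assert response_json["top_documents"][0]["document_id"] == "doc-123"
    assert response_json["simple_search_docs"][0]["id"] == "doc-123"
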
@@ -1,42 +0,0 @@
from collections.abc import Callable

import pytest

from onyx.configs.constants import DocumentSource
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import SimpleTestDocument


DocumentBuilderType = Callable[[list[str]], list[SimpleTestDocument]]


@pytest.fixture
def document_builder(admin_user: DATestUser) -> DocumentBuilderType:
    api_key: DATestAPIKey = APIKeyManager.create(
        user_performing_action=admin_user,
    )

    # create connector
    cc_pair_1 = CCPairManager.create_from_scratch(
        source=DocumentSource.INGESTION_API,
        user_performing_action=admin_user,
    )

    def _document_builder(contents: list[str]) -> list[SimpleTestDocument]:
        # seed documents
        docs: list[SimpleTestDocument] = [
            DocumentManager.seed_doc_with_content(
                cc_pair=cc_pair_1,
                content=content,
                api_key=api_key,
            )
            for content in contents
        ]

        return docs

    return _document_builder
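The deleted conftest above is an instance of the pytest factory-fixture pattern: expensive setup (API key, ingestion connector) runs once per test, and the fixture returns a callable the test invokes for each batch of documents. A self-contained sketch of the same shape, with the Onyx managers stubbed out:

    from collections.abc import Callable

    import pytest

    DocumentBuilderType = Callable[[list[str]], list[str]]


    @pytest.fixture
    def document_builder() -> DocumentBuilderType:
        # one-time setup would go here (API key, connector, ...)
        def _document_builder(contents: list[str]) -> list[str]:
            # per-call seeding; the real fixture returns SimpleTestDocument objects
            return [content.upper() for content in contents]

        return _document_builder


    def test_document_builder(document_builder: DocumentBuilderType) -> None:
        assert document_builder(["a", "b"]) == ["A", "B"]
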
@@ -5,11 +5,12 @@ import pytest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.tests.streaming_endpoints.conftest import DocumentBuilderType


def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
def test_send_message_simple_with_history(reset: None) -> None:
    admin_user: DATestUser = UserManager.create(name="admin_user")
    LLMProviderManager.create(user_performing_action=admin_user)

    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
@@ -23,44 +24,6 @@ def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -
    assert len(response.full_message) > 0


def test_send_message__basic_searches(
    reset: None, admin_user: DATestUser, document_builder: DocumentBuilderType
) -> None:
    MESSAGE = "run a search for 'test'"
    SHORT_DOC_CONTENT = "test"
    LONG_DOC_CONTENT = "blah blah blah blah" * 100

    LLMProviderManager.create(user_performing_action=admin_user)

    short_doc = document_builder([SHORT_DOC_CONTENT])[0]

    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
    response = ChatSessionManager.send_message(
        chat_session_id=test_chat_session.id,
        message=MESSAGE,
        user_performing_action=admin_user,
    )
    assert response.top_documents is not None
    assert len(response.top_documents) == 1
    assert response.top_documents[0].document_id == short_doc.id

    # make sure this doc is really long so that it will be split into multiple chunks
    long_doc = document_builder([LONG_DOC_CONTENT])[0]

    # new chat session for simplicity
    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
    response = ChatSessionManager.send_message(
        chat_session_id=test_chat_session.id,
        message=MESSAGE,
        user_performing_action=admin_user,
    )
    assert response.top_documents is not None
    assert len(response.top_documents) == 2
    # short doc should be more relevant and thus first
    assert response.top_documents[0].document_id == short_doc.id
    assert response.top_documents[1].document_id == long_doc.id


@pytest.mark.skip(
    reason="enable for autorun when we have a testing environment with semantically useful data"
)

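The removed `test_send_message__basic_searches` above leans on an implicit property: a document long enough to be split into several chunks can match several times, yet must surface exactly once in `top_documents`. A toy sketch of that chunk-to-document deduplication (an assumption about why the `len == 2` assertion holds, not the repo's actual retrieval code):

    def dedupe_chunk_hits(chunk_hits: list[str]) -> list[str]:
        # collapse ranked chunk-level hits to unique document ids, keeping order
        seen: set[str] = set()
        docs: list[str] = []
        for doc_id in chunk_hits:
            if doc_id not in seen:
                seen.add(doc_id)
                docs.append(doc_id)
        return docs


    # short doc matches once, long doc matches in three of its chunks
    assert dedupe_chunk_hits(["short", "long", "long", "long"]) == ["short", "long"]
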
@@ -9,6 +9,8 @@ from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import DocumentSource
@@ -17,6 +19,7 @@ from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import (
    FINAL_CONTEXT_DOCUMENTS_ID,
@@ -117,7 +120,24 @@ def mock_search_results(


@pytest.fixture
def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts:
    return OnyxContexts(
        contexts=[
            OnyxContext(
                content=section.combined_content,
                document_id=section.center_chunk.document_id,
                semantic_identifier=section.center_chunk.semantic_identifier,
                blurb=section.center_chunk.blurb,
            )
            for section in mock_inference_sections
        ]
    )


@pytest.fixture
def mock_search_tool(
    mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc]
) -> MagicMock:
    mock_tool = MagicMock(spec=SearchTool)
    mock_tool.name = "search"
    mock_tool.build_tool_message_content.return_value = "search_response"
@@ -126,6 +146,7 @@ def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
        json.loads(doc.model_dump_json()) for doc in mock_search_results
    ]
    mock_tool.run.return_value = [
        ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts),
        ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results),
    ]
    mock_tool.tool_definition.return_value = {

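The fixture change above makes the mocked `SearchTool` emit two packets from `run` in a fixed order: the document-content packet first, then the final context documents. A reduced sketch of that stub wiring (plain dicts stand in for `ToolResponse` and its payloads, and the `SEARCH_DOC_CONTENT_ID` value is an assumption; `final_context_documents` matches the string asserted later in the diff):

    from unittest.mock import MagicMock

    SEARCH_DOC_CONTENT_ID = "search_doc_content"  # assumed constant value
    FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"

    mock_tool = MagicMock()
    mock_tool.name = "search"
    # run() yields packets in the order downstream assertions expect
    mock_tool.run.return_value = [
        {"id": SEARCH_DOC_CONTENT_ID, "response": "contexts payload"},
        {"id": FINAL_CONTEXT_DOCUMENTS_ID, "response": ["final docs payload"]},
    ]

    packets = mock_tool.run()
    assert [p["id"] for p in packets] == [
        SEARCH_DOC_CONTENT_ID,
        FINAL_CONTEXT_DOCUMENTS_ID,
    ]
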
@@ -19,6 +19,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
@@ -32,6 +33,7 @@ from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search_like_tool_utils import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -139,6 +141,7 @@ def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
def test_answer_with_search_call(
    answer_instance: Answer,
    mock_search_results: list[LlmDoc],
    mock_contexts: OnyxContexts,
    mock_search_tool: MagicMock,
    force_use_tool: ForceUseTool,
    expected_tool_args: dict,
@@ -194,21 +197,25 @@ def test_answer_with_search_call(
        tool_name="search", tool_args=expected_tool_args
    )
    assert output[1] == ToolResponse(
        id=SEARCH_DOC_CONTENT_ID,
        response=mock_contexts,
    )
    assert output[2] == ToolResponse(
        id="final_context_documents",
        response=mock_search_results,
    )
    assert output[2] == ToolCallFinalResult(
    assert output[3] == ToolCallFinalResult(
        tool_name="search",
        tool_args=expected_tool_args,
        tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
    )
    assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
    assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
    expected_citation = CitationInfo(citation_num=1, document_id="doc1")
    assert output[4] == expected_citation
    assert output[5] == OnyxAnswerPiece(
    assert output[5] == expected_citation
    assert output[6] == OnyxAnswerPiece(
        answer_piece="the answer is abc[[1]](https://example.com/doc1). "
    )
    assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
    assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")

    expected_answer = (
        "Based on the search results, "
@@ -261,6 +268,7 @@ def test_answer_with_search_call(
def test_answer_with_search_no_tool_calling(
    answer_instance: Answer,
    mock_search_results: list[LlmDoc],
    mock_contexts: OnyxContexts,
    mock_search_tool: MagicMock,
) -> None:
    answer_instance.graph_config.tooling.tools = [mock_search_tool]
@@ -280,26 +288,30 @@ def test_answer_with_search_no_tool_calling(
    output = list(answer_instance.processed_streamed_output)

    # Assertions
    assert len(output) == 7
    assert len(output) == 8
    assert output[0] == ToolCallKickoff(
        tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
    )
    assert output[1] == ToolResponse(
        id=SEARCH_DOC_CONTENT_ID,
        response=mock_contexts,
    )
    assert output[2] == ToolResponse(
        id=FINAL_CONTEXT_DOCUMENTS_ID,
        response=mock_search_results,
    )
    assert output[2] == ToolCallFinalResult(
    assert output[3] == ToolCallFinalResult(
        tool_name="search",
        tool_args=DEFAULT_SEARCH_ARGS,
        tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
    )
    assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
    assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
    expected_citation = CitationInfo(citation_num=1, document_id="doc1")
    assert output[4] == expected_citation
    assert output[5] == OnyxAnswerPiece(
    assert output[5] == expected_citation
    assert output[6] == OnyxAnswerPiece(
        answer_piece="the answer is abc[[1]](https://example.com/doc1). "
    )
    assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
    assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")

    expected_answer = (
        "Based on the search results, "

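Every index bump in the assertions above follows from one structural change: the stream gains a single extra packet (the `SEARCH_DOC_CONTENT_ID` response) right after the tool kickoff, so everything downstream shifts by one. A toy illustration of that reindexing:

    old_stream = ["kickoff", "final_docs", "final_result", "answer piece"]
    new_stream = ["kickoff", "doc_content", "final_docs", "final_result", "answer piece"]

    # one insertion at index 1 shifts every later packet by exactly one
    assert len(new_stream) == len(old_stream) + 1
    assert new_stream[0] == old_stream[0]
    assert new_stream[2:] == old_stream[1:]
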
@@ -79,7 +79,7 @@ def test_skip_gen_ai_answer_generation_flag(
    for res in results:
        print(res)

    expected_count = 3 if skip_gen_ai_answer_generation else 4
    expected_count = 4 if skip_gen_ai_answer_generation else 5
    assert len(results) == expected_count
    if not skip_gen_ai_answer_generation:
        mock_llm.stream.assert_called_once()

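The same off-by-one applies to the skip-generation test above: with the extra document-content packet, the expected packet count rises from 3/4 to 4/5 depending on the flag. A sketch of that expectation (the packet breakdown in the comment is an inference from the surrounding diff, not confirmed by the source):

    def expected_packet_count(skip_gen_ai_answer_generation: bool) -> int:
        # kickoff + doc-content + final-docs + final-result, plus one
        # answer packet when generation is not skipped
        return 4 if skip_gen_ai_answer_generation else 5


    assert expected_packet_count(True) == 4
    assert expected_packet_count(False) == 5
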
@@ -45,7 +45,7 @@ export function ActionsTable({ tools }: { tools: ToolSnapshot[] }) {
  className="mr-1 my-auto cursor-pointer"
  onClick={() =>
    router.push(
      `/admin/actions/edit/${tool.id}?u=${Date.now()}`
      `/admin/tools/edit/${tool.id}?u=${Date.now()}`
    )
  }
/>

@@ -281,7 +281,7 @@ export default function AddConnector({
  return (
    <Formik
      initialValues={{
        ...createConnectorInitialValues(connector, currentCredential),
        ...createConnectorInitialValues(connector),
        ...Object.fromEntries(
          connectorConfigs[connector].advanced_values.map((field) => [
            field.name,

@@ -148,7 +148,8 @@ export function Explorer({
      clearTimeout(timeoutId);
    }

    if (query && query.trim() !== "") {
      let doSearch = true;
      if (doSearch) {
        router.replace(
          `/admin/documents/explorer?query=${encodeURIComponent(query)}`
        );

@@ -34,12 +34,11 @@
  /* -------------------------------------------------------
   * 2. Keep special, custom, or near-duplicate background
   * ------------------------------------------------------- */
  --background: #fefcfa; /* slightly off-white */
  --background-50: #fffdfb; /* a little lighter than background but not quite white */
  --background: #fefcfa; /* slightly off-white, keep it */
  --input-background: #fefcfa;
  --input-border: #f1eee8;
  --text-text: #f4f2ed;
  --background-dark: #141414;
  --background-dark: #e9e6e0;
  --new-background: #ebe7de;
  --new-background-light: #d9d1c0;
  --background-chatbar: #f5f3ee;
@@ -235,7 +234,6 @@

  --text-text: #1d1d1d;
  --background-dark: #252525;
  --background-50: #252525;

  /* --new-background: #fff; */
  --new-background: #2c2c2c;

@@ -181,7 +181,7 @@ const SignedUpUserTable = ({
        : "All Roles"}
      </SelectValue>
    </SelectTrigger>
    <SelectContent className="bg-background-50">
    <SelectContent className="bg-background">
      {Object.entries(USER_ROLE_LABELS)
        .filter(([role]) => role !== UserRole.EXT_PERM_USER)
        .map(([role, label]) => (

@@ -26,13 +26,7 @@ export const buildDocumentSummaryDisplay = (
  matchHighlights: string[],
  blurb: string
) => {
  // if there are no match highlights, or if it's really short, just use the blurb
  // this is to prevent the UI from showing something like `...` for the summary
  const MIN_MATCH_HIGHLIGHT_LENGTH = 5;
  if (
    !matchHighlights ||
    matchHighlights.length <= MIN_MATCH_HIGHLIGHT_LENGTH
  ) {
  if (!matchHighlights || matchHighlights.length === 0) {
    return blurb;
  }


@@ -102,7 +102,7 @@ export function UserProvider({
  };

  // Use the custom token refresh hook
  useTokenRefresh(upToDateUser, fetchUser);
  // useTokenRefresh(upToDateUser, fetchUser);

  const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => {
    try {

@@ -1292,8 +1292,7 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
  },
};
export function createConnectorInitialValues(
  connector: ConfigurableSources,
  currentCredential: Credential<any> | null = null
  connector: ConfigurableSources
): Record<string, any> & AccessTypeGroupSelectorFormType {
  const configuration = connectorConfigs[connector];

@@ -1308,16 +1307,7 @@ export function createConnectorInitialValues(
  } else if (field.type === "list") {
    acc[field.name] = field.default || [];
  } else if (field.type === "checkbox") {
    // Special case for include_files_shared_with_me when using service account
    if (
      field.name === "include_files_shared_with_me" &&
      currentCredential &&
      !currentCredential.credential_json?.google_tokens
    ) {
      acc[field.name] = true;
    } else {
      acc[field.name] = field.default || false;
    }
    acc[field.name] = field.default || false;
  } else if (field.default !== undefined) {
    acc[field.name] = field.default;
  }

@@ -108,7 +108,6 @@ module.exports = {
  "accent-background": "var(--accent-background)",
  "accent-background-hovered": "var(--accent-background-hovered)",
  "accent-background-selected": "var(--accent-background-selected)",
  "background-50": "var(--background-50)",
  "background-dark": "var(--off-white)",
  "background-100": "var(--neutral-100-border-light)",
  "background-125": "var(--neutral-125)",