Compare commits

..

1 Commit

Author SHA1 Message Date
Richard Kuo (Onyx)
54e61611c5 prototype for surfacing docs without a query 2025-03-27 16:52:31 -07:00
64 changed files with 437 additions and 1115 deletions

View File

@@ -1,209 +0,0 @@
# GitHub Actions workflow: builds the Onyx backend/model-server/integration-test
# Docker images, boots the full stack via docker compose, waits for the API
# server to become healthy, then runs the MIT integration test suite.
name: Run MIT Integration Tests v2

concurrency:
  group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches:
      - main
      - "release/**"

env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
  CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}

jobs:
  integration-tests:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      # tag every docker image with "test" so that we can spin up the correct set
      # of images during testing

      # We don't need to build the Web Docker image since it's not yet used
      # in the integration tests. We have a separate action to verify that it builds
      # successfully.
      - name: Pull Web Docker image
        run: |
          docker pull onyxdotapp/onyx-web-server:latest
          docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test

      # we use the runs-on cache for docker builds
      # in conjunction with runs-on runners, it has better speed and unlimited caching
      # https://runs-on.com/caching/s3-cache-for-github-actions/
      # https://runs-on.com/caching/docker/
      # https://github.com/moby/buildkit#s3-cache-experimental

      # images are built and run locally for testing purposes. Not pushed.
      - name: Build Backend Docker image
        uses: ./.github/actions/custom-build-and-push
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/amd64
          tags: onyxdotapp/onyx-backend:test
          push: false
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      - name: Build Model Server Docker image
        uses: ./.github/actions/custom-build-and-push
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/amd64
          tags: onyxdotapp/onyx-model-server:test
          push: false
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      - name: Build integration test Docker image
        uses: ./.github/actions/custom-build-and-push
        with:
          context: ./backend
          file: ./backend/tests/integration/Dockerfile
          platforms: linux/amd64
          tags: onyxdotapp/onyx-integration:test
          push: false
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
      - name: Start Docker containers
        run: |
          cd deployment/docker_compose
          AUTH_TYPE=basic \
          POSTGRES_POOL_PRE_PING=true \
          POSTGRES_USE_NULL_POOL=true \
          REQUIRE_EMAIL_VERIFICATION=false \
          DISABLE_TELEMETRY=true \
          IMAGE_TAG=test \
          INTEGRATION_TESTS_MODE=true \
          docker compose -f docker-compose.dev.yml -p onyx-stack up -d
        id: start_docker

      - name: Wait for service to be ready
        run: |
          echo "Starting wait-for-service script..."
          docker logs -f onyx-stack-api_server-1 &
          start_time=$(date +%s)
          timeout=300 # 5 minutes in seconds
          while true; do
            current_time=$(date +%s)
            elapsed_time=$((current_time - start_time))
            if [ $elapsed_time -ge $timeout ]; then
              echo "Timeout reached. Service did not become ready in 5 minutes."
              exit 1
            fi
            # Use curl with error handling to ignore specific exit code 56
            response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
            if [ "$response" = "200" ]; then
              echo "Service is ready!"
              break
            elif [ "$response" = "curl_error" ]; then
              echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
            else
              echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
            fi
            sleep 5
          done
          echo "Finished waiting for service."

      - name: Start Mock Services
        run: |
          cd backend/tests/integration/mock_services
          docker compose -f docker-compose.mock-it-services.yml \
            -p mock-it-services-stack up -d

      # NOTE: Use pre-ping/null to reduce flakiness due to dropped connections
      - name: Run Standard Integration Tests
        run: |
          echo "Running integration tests..."
          docker run --rm --network onyx-stack_default \
            --name test-runner \
            -e POSTGRES_HOST=relational_db \
            -e POSTGRES_USER=postgres \
            -e POSTGRES_PASSWORD=password \
            -e POSTGRES_DB=postgres \
            -e POSTGRES_POOL_PRE_PING=true \
            -e POSTGRES_USE_NULL_POOL=true \
            -e VESPA_HOST=index \
            -e REDIS_HOST=cache \
            -e API_SERVER_HOST=api_server \
            -e OPENAI_API_KEY=${OPENAI_API_KEY} \
            -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
            -e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
            -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
            -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
            -e TEST_WEB_HOSTNAME=test-runner \
            -e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
            -e MOCK_CONNECTOR_SERVER_PORT=8001 \
            onyxdotapp/onyx-integration:test \
            /app/tests/integration/tests \
            /app/tests/integration/connector_job_tests
        continue-on-error: true
        id: run_tests

      - name: Check test results
        run: |
          if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
            echo "Integration tests failed. Exiting with error."
            exit 1
          else
            echo "All integration tests passed successfully."
          fi

      # ------------------------------------------------------------
      # Always gather logs BEFORE "down":
      - name: Dump API server logs
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true

      - name: Dump all-container logs (optional)
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

      - name: Upload logs
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: docker-all-logs
          path: ${{ github.workspace }}/docker-compose.log

      # ------------------------------------------------------------
      - name: Stop Docker containers
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose -f docker-compose.dev.yml -p onyx-stack down -v

View File

@@ -93,12 +93,12 @@ def _get_access_for_documents(
)
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess.build(
user_emails=list(non_ee_access.user_emails),
user_groups=user_group_info.get(document_id, []),
access_map[document_id] = DocumentAccess(
user_emails=non_ee_access.user_emails,
user_groups=set(user_group_info.get(document_id, [])),
is_public=is_public_anywhere,
external_user_emails=list(ext_u_emails),
external_user_group_ids=list(ext_u_groups),
external_user_emails=ext_u_emails,
external_user_group_ids=ext_u_groups,
)
return access_map

View File

@@ -2,6 +2,7 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
@@ -31,6 +32,8 @@ def gather_stream_for_answer_api(
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
elif isinstance(packet, OnyxContexts):
response.contexts = packet
if answer:
response.answer = answer

View File

@@ -14,6 +14,7 @@ from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
@@ -55,6 +56,25 @@ logger = setup_logger()
router = APIRouter(prefix="/chat")
def _translate_doc_response_to_simple_doc(
doc_response: QADocsResponse,
) -> list[SimpleDoc]:
return [
SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier,
link=doc.link,
blurb=doc.blurb,
match_highlights=[
highlight for highlight in doc.match_highlights if highlight
],
source_type=doc.source_type,
metadata=doc.metadata,
)
for doc in doc_response.top_documents
]
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,
@@ -91,6 +111,9 @@ def _convert_packet_stream_to_response(
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)

View File

@@ -8,6 +8,7 @@ from pydantic import model_validator
from ee.onyx.server.manage.models import StandardAnswer
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
@@ -163,6 +164,8 @@ class ChatBasicResponse(BaseModel):
cited_documents: dict[int, str] | None = None
# FOR BACKWARDS COMPATIBILITY
# TODO: deprecate both of these
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
@@ -217,3 +220,4 @@ class OneShotQAResponse(BaseModel):
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
contexts: OnyxContexts | None = None

View File

@@ -18,7 +18,7 @@ def _get_access_for_document(
document_id=document_id,
)
doc_access = DocumentAccess.build(
return DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
@@ -26,8 +26,6 @@ def _get_access_for_document(
is_public=info[2] if info else False,
)
return doc_access
def get_access_for_document(
document_id: str,
@@ -40,12 +38,12 @@ def get_access_for_document(
def get_null_document_access() -> DocumentAccess:
return DocumentAccess.build(
user_emails=[],
user_groups=[],
return DocumentAccess(
user_emails=set(),
user_groups=set(),
is_public=False,
external_user_emails=[],
external_user_group_ids=[],
external_user_emails=set(),
external_user_group_ids=set(),
)
@@ -58,18 +56,18 @@ def _get_access_for_documents(
document_ids=document_ids,
)
doc_access = {
document_id: DocumentAccess.build(
user_emails=[email for email in user_emails if email],
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
# MIT version will wipe all groups and external groups on update
user_groups=[],
user_groups=set(),
is_public=is_public,
external_user_emails=[],
external_user_group_ids=[],
external_user_emails=set(),
external_user_group_ids=set(),
)
for document_id, user_emails, is_public in document_access_info
}
# Sometimes the document has not been indexed by the indexing job yet, in those cases
# Sometimes the document has not been indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.

View File

@@ -56,45 +56,33 @@ class DocExternalAccess:
)
@dataclass(frozen=True, init=False)
@dataclass(frozen=True)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
external_user_emails: set[str]
external_user_group_ids: set[str]
is_public: bool
def __init__(self) -> None:
raise TypeError(
"Use `DocumentAccess.build(...)` instead of creating an instance directly."
)
def to_acl(self) -> set[str]:
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:
acl_set.add(prefix_user_email(user_email))
for group_name in self.user_groups:
acl_set.add(prefix_user_group(group_name))
for external_user_email in self.external_user_emails:
acl_set.add(prefix_user_email(external_user_email))
for external_group_id in self.external_user_group_ids:
acl_set.add(prefix_external_group(external_group_id))
if self.is_public:
acl_set.add(PUBLIC_DOC_PAT)
return acl_set
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
)
@classmethod
def build(
@@ -105,32 +93,29 @@ class DocumentAccess(ExternalAccess):
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
"""Don't prefix incoming data wth acl type, prefix on read from to_acl!"""
obj = object.__new__(cls)
object.__setattr__(
obj, "user_emails", {user_email for user_email in user_emails if user_email}
return cls(
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
)
object.__setattr__(obj, "user_groups", set(user_groups))
object.__setattr__(
obj,
"external_user_emails",
{external_email for external_email in external_user_emails},
)
object.__setattr__(
obj,
"external_user_group_ids",
{external_group_id for external_group_id in external_user_group_ids},
)
object.__setattr__(obj, "is_public", is_public)
return obj
default_public_access = DocumentAccess.build(
external_user_emails=[],
external_user_group_ids=[],
user_emails=[],
user_groups=[],
default_public_access = DocumentAccess(
external_user_emails=set(),
external_user_group_ids=set(),
user_emails=set(),
user_groups=set(),
is_public=True,
)

View File

@@ -7,6 +7,7 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
@@ -23,7 +24,7 @@ def process_llm_stream(
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")

View File

@@ -156,6 +156,7 @@ def generate_initial_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -183,6 +183,7 @@ def generate_validate_refined_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -57,6 +57,7 @@ def format_results(
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -13,7 +13,9 @@ from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -57,7 +59,9 @@ def basic_use_tool_response(
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(section_to_llm_doc(section))
initial_search_results.append(
context_from_inference_section(section)
)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -389,8 +389,6 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
credential_id_to_delete: int | None = None
connector_id_to_delete: int | None = None
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
@@ -445,35 +443,26 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)
# Store IDs before potentially expiring cc_pair
connector_id_to_delete = cc_pair.connector_id
credential_id_to_delete = cc_pair.credential_id
# Explicitly delete document by connector credential pair records before deleting the connector
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
delete_all_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# Flush to ensure document deletion happens before connector deletion
db_session.flush()
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=connector_id_to_delete,
connector_id=cc_pair.connector_id,
)
if not connector or not len(connector.credentials):
task_logger.info(
@@ -506,15 +495,15 @@ def monitor_connector_deletion_taskset(
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={connector_id_to_delete} "
f"credential={credential_id_to_delete} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"docs_deleted={fence_data.num_tasks}"
)
@@ -564,7 +553,7 @@ def validate_connector_deletion_fences(
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_upsert_tasks: set[str],
queued_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
@@ -651,7 +640,7 @@ def validate_connector_deletion_fence(
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_upsert_tasks:
if member_str in queued_tasks:
continue
tasks_not_in_celery += 1

View File

@@ -194,6 +194,17 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class OnyxContext(BaseModel):
content: str
document_id: str
semantic_identifier: str
blurb: str
class OnyxContexts(BaseModel):
contexts: list[OnyxContext]
class OnyxAnswer(BaseModel):
answer: str | None
@@ -259,6 +270,7 @@ class PersonaOverrideConfig(BaseModel):
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
| OnyxContexts
| FileChatDisplay
| CustomToolResponse
| StreamingError

View File

@@ -29,6 +29,7 @@ from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
@@ -130,6 +131,7 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -298,6 +300,7 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| OnyxContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -916,6 +919,8 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(OnyxContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:

View File

@@ -301,10 +301,6 @@ def prune_sections(
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)

View File

@@ -3,6 +3,7 @@ from collections.abc import Sequence
from pydantic import BaseModel
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceChunk
@@ -11,7 +12,7 @@ class DocumentIdOrderMapping(BaseModel):
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
chunks: Sequence[InferenceChunk | LlmDoc | OnyxContext], one_indexed: bool = True
) -> DocumentIdOrderMapping:
order_mapping = {}
current = 1 if one_indexed else 0

View File

@@ -28,9 +28,7 @@ from onyx.connectors.google_drive.doc_conversion import (
)
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import DriveRetrievalStage
@@ -88,18 +86,13 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
size_threshold: int,
retriever_email: str,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
# We used to always get the user email from the file owners when available,
# but this was causing issues with shared folders where the owner was not included in the service account
# now we use the email of the account that successfully listed the file. Leaving this in case we end up
# wanting to retry with file owners and/or admin email at some point.
# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = retriever_email
# Only construct these services when needed
user_drive_service = lazy_eval(
lambda: get_drive_service(creds, user_email=user_email)
@@ -457,11 +450,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.info(f"Getting all files in my drive as '{user_email}'")
yield from add_retrieval_info(
get_all_files_in_my_drive_and_shared(
get_all_files_in_my_drive(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
include_shared_with_me=self.include_files_shared_with_me,
start=curr_stage.completed_until if resuming else start,
end=end,
),
@@ -924,28 +916,20 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[RetrievedDriveFile] = []
files_batch: list[GoogleDriveFileType] = []
def _yield_batch(
files_batch: list[RetrievedDriveFile],
files_batch: list[GoogleDriveFileType],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [
(
convert_func,
(
file.user_email,
file.drive_file,
),
)
for file in files_batch
]
func_with_args = [(convert_func, (file,)) for file in files_batch]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
@@ -983,7 +967,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)
continue
files_batch.append(retrieved_file)
files_batch.append(retrieved_file.drive_file)
if len(files_batch) < self.batch_size:
continue

View File

@@ -87,17 +87,35 @@ def _download_and_extract_sections_basic(
mime_type = file["mimeType"]
link = file.get("webViewLink", "")
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
@@ -106,100 +124,88 @@ def _download_and_extract_sections_basic(
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
logger.warning(f"Failed to download {file_name}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []
def convert_drive_item_to_document(

View File

@@ -214,11 +214,10 @@ def get_files_in_shared_drive(
yield file
def get_all_files_in_my_drive_and_shared(
def get_all_files_in_my_drive(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool,
include_shared_with_me: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
@@ -230,8 +229,7 @@ def get_all_files_in_my_drive_and_shared(
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -248,8 +246,7 @@ def get_all_files_in_my_drive_and_shared(
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,

View File

@@ -75,7 +75,7 @@ class HighspotClient:
self.key = key
self.secret = secret
self.base_url = base_url.rstrip("/") + "/"
self.base_url = base_url
self.timeout = timeout
# Set up session with retry logic

View File

@@ -339,12 +339,6 @@ class SearchPipeline:
self._retrieved_sections = self._get_sections()
return self._retrieved_sections
@property
def merged_retrieved_sections(self) -> list[InferenceSection]:
"""Should be used to display in the UI in order to prevent displaying
multiple sections for the same document as separate "documents"."""
return _merge_sections(sections=self.retrieved_sections)
@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -421,10 +415,6 @@ class SearchPipeline:
raise ValueError(
"Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
)
# NOTE: final_context_sections must be accessed before accessing self._postprocessing_generator
# since the property sets the generator. DO NOT REMOVE.
_ = self.final_context_sections
self._section_relevance = next(
cast(
Iterator[list[SectionRelevancePiece]],

View File

@@ -821,26 +821,30 @@ class VespaIndex(DocumentIndex):
num_to_retrieve: int = NUM_RETURNED_HITS,
offset: int = 0,
) -> list[InferenceChunkUncleaned]:
vespa_where_clauses = build_vespa_filters(filters, include_hidden=True)
yql = (
YQL_BASE.format(index_name=self.index_name)
+ vespa_where_clauses
+ '({grammar: "weakAnd"}userInput(@query) '
# `({defaultIndex: "content_summary"}userInput(@query))` section is
# needed for highlighting while the N-gram highlighting is broken /
# not working as desired
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
vespa_where_clauses = build_vespa_filters(
filters, include_hidden=True, remove_trailing_and=True
)
yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses
params: dict[str, str | int] = {
"yql": yql,
"query": query,
"hits": num_to_retrieve,
"offset": 0,
"ranking.profile": "admin_search",
"timeout": VESPA_TIMEOUT,
}
if len(query.strip()) > 0:
yql += (
' and ({grammar: "weakAnd"}userInput(@query) '
# `({defaultIndex: "content_summary"}userInput(@query))` section is
# needed for highlighting while the N-gram highlighting is broken /
# not working as desired
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
)
params["yql"] = yql
params["query"] = query
return query_vespa(params)
# Retrieves chunk information for a document:

View File

@@ -12,6 +12,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
@@ -41,6 +42,9 @@ from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from onyx.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
@@ -54,6 +58,7 @@ from onyx.utils.special_types import JSON_ro
logger = setup_logger()
SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"
@@ -352,13 +357,13 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
)
yield from yield_search_responses(
query=query,
# give back the merged sections to prevent duplicate docs from appearing in the UI
get_retrieved_sections=lambda: search_pipeline.merged_retrieved_sections,
get_final_context_sections=lambda: search_pipeline.final_context_sections,
search_query_info=search_query_info,
get_section_relevance=lambda: search_pipeline.section_relevance,
search_tool=self,
query,
lambda: search_pipeline.retrieved_sections,
lambda: search_pipeline.reranked_sections,
lambda: search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
)
def final_result(self, *args: ToolResponse) -> JSON_ro:
@@ -400,6 +405,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
def yield_search_responses(
query: str,
get_retrieved_sections: Callable[[], list[InferenceSection]],
get_reranked_sections: Callable[[], list[InferenceSection]],
get_final_context_sections: Callable[[], list[InferenceSection]],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
@@ -417,6 +423,16 @@ def yield_search_responses(
),
)
yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
context_from_inference_section(section)
for section in get_reranked_sections()
]
),
)
section_relevance = get_section_relevance()
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,

View File

@@ -1,4 +1,5 @@
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceSection
from onyx.prompts.prompt_utils import clean_up_source
@@ -31,23 +32,10 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
return doc_dict
def section_to_llm_doc(section: InferenceSection) -> LlmDoc:
possible_link_chunks = [section.center_chunk] + section.chunks
link: str | None = None
for chunk in possible_link_chunks:
if chunk.source_links:
link = list(chunk.source_links.values())[0]
break
return LlmDoc(
document_id=section.center_chunk.document_id,
def context_from_inference_section(section: InferenceSection) -> OnyxContext:
return OnyxContext(
content=section.combined_content,
source_type=section.center_chunk.source_type,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
metadata=section.center_chunk.metadata,
updated_at=section.center_chunk.updated_at,
blurb=section.center_chunk.blurb,
link=link,
source_links=section.center_chunk.source_links,
match_highlights=section.center_chunk.match_highlights,
)

View File

@@ -78,19 +78,19 @@ def generate_dummy_chunk(
for i in range(number_of_document_sets):
document_set_names.append(f"Document Set {i}")
user_emails: list[str | None] = []
user_groups: list[str] = []
external_user_emails: list[str] = []
external_user_group_ids: list[str] = []
user_emails: set[str | None] = set()
user_groups: set[str] = set()
external_user_emails: set[str] = set()
external_user_group_ids: set[str] = set()
for i in range(number_of_acl_entries):
user_emails.append(f"user_{i}@example.com")
user_groups.append(f"group_{i}")
external_user_emails.append(f"external_user_{i}@example.com")
external_user_group_ids.append(f"external_group_{i}")
user_emails.add(f"user_{i}@example.com")
user_groups.add(f"group_{i}")
external_user_emails.add(f"external_user_{i}@example.com")
external_user_group_ids.add(f"external_group_{i}")
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=DocumentAccess.build(
access=DocumentAccess(
user_emails=user_emails,
user_groups=user_groups,
external_user_emails=external_user_emails,

View File

@@ -58,16 +58,6 @@ SECTIONS_FOLDER_URL = (
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
)
EXTERNAL_SHARED_FOLDER_URL = (
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
)
EXTERNAL_SHARED_DOCS_IN_FOLDER = [
"https://docs.google.com/document/d/1Sywmv1-H6ENk2GcgieKou3kQHR_0te1mhIUcq8XlcdY"
]
EXTERNAL_SHARED_DOC_SINGLETON = (
"https://docs.google.com/document/d/11kmisDfdvNcw5LYZbkdPVjTOdj-Uc5ma6Jep68xzeeA"
)
SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA"
ADMIN_EMAIL = "admin@onyx-test.com"

View File

@@ -1,7 +1,6 @@
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from urllib.parse import urlparse
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
@@ -10,15 +9,6 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOC_SINGLETON,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOCS_IN_FOLDER,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
@@ -110,8 +100,7 @@ def test_include_shared_drives_only_with_size_threshold(
retrieved_docs = load_all_docs(connector)
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 52
assert len(retrieved_docs) == 50
@patch(
@@ -148,8 +137,7 @@ def test_include_shared_drives_only(
+ SECTIONS_FILE_IDS
)
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 53
assert len(retrieved_docs) == 51
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
@@ -306,64 +294,6 @@ def test_folders_only(
)
def test_shared_folder_owned_by_external_user(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_folder_owned_by_external_user")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=False,
include_files_shared_with_me=False,
shared_drive_urls=None,
shared_folder_urls=EXTERNAL_SHARED_FOLDER_URL,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
expected_docs = EXTERNAL_SHARED_DOCS_IN_FOLDER
assert len(retrieved_docs) == len(expected_docs) # 1 for now
assert expected_docs[0] in retrieved_docs[0].id
def test_shared_with_me(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_with_me")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=True,
include_files_shared_with_me=True,
shared_drive_urls=None,
shared_folder_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
print(retrieved_docs)
expected_file_ids = (
ADMIN_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
+ TEST_USER_1_FILE_IDS
+ TEST_USER_2_FILE_IDS
+ TEST_USER_3_FILE_IDS
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
retrieved_ids = {urlparse(doc.id).path.split("/")[-2] for doc in retrieved_docs}
for id in retrieved_ids:
print(id)
assert EXTERNAL_SHARED_DOC_SINGLETON.split("/")[-1] in retrieved_ids
assert EXTERNAL_SHARED_DOCS_IN_FOLDER[0].split("/")[-1] in retrieved_ids
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,

View File

@@ -6,7 +6,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 60
MAX_DELAY = 45
GENERAL_HEADERS = {"Content-Type": "application/json"}

View File

@@ -5,7 +5,6 @@ import requests
from requests.models import Response
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
@@ -98,24 +97,17 @@ class ChatSessionManager:
for data in response_data:
if "rephrased_query" in data:
analyzed.rephrased_query = data["rephrased_query"]
if "tool_name" in data:
elif "tool_name" in data:
analyzed.tool_name = data["tool_name"]
analyzed.tool_result = (
data.get("tool_result")
if analyzed.tool_name == "run_search"
else None
)
if "relevance_summaries" in data:
elif "relevance_summaries" in data:
analyzed.relevance_summaries = data["relevance_summaries"]
if "answer_piece" in data and data["answer_piece"]:
elif "answer_piece" in data and data["answer_piece"]:
analyzed.full_message += data["answer_piece"]
if "top_documents" in data:
assert (
analyzed.top_documents is None
), "top_documents should only be set once"
analyzed.top_documents = [
SavedSearchDoc(**doc) for doc in data["top_documents"]
]
return analyzed

View File

@@ -10,7 +10,6 @@ from pydantic import Field
from onyx.auth.schemas import UserRole
from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.models import SavedSearchDoc
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import IndexAttemptSnapshot
@@ -158,7 +157,7 @@ class StreamedResponse(BaseModel):
full_message: str = ""
rephrased_query: str | None = None
tool_name: str | None = None
top_documents: list[SavedSearchDoc] | None = None
top_documents: list[dict[str, Any]] | None = None
relevance_summaries: list[dict[str, Any]] | None = None
tool_result: Any | None = None
user: str | None = None

View File

@@ -1,6 +1,3 @@
import os
import pytest
import requests
from onyx.auth.schemas import UserRole
@@ -9,10 +6,6 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
)
def test_saml_user_conversion(reset: None) -> None:
"""
Test that SAML login correctly converts users with non-authenticated roles

View File

@@ -5,7 +5,6 @@ This file contains tests for the following:
- updates the document sets and user groups to remove the connector
- Ensure that deleting a connector that is part of an overlapping document set and/or user group works as expected
"""
import os
from uuid import uuid4
from sqlalchemy.orm import Session
@@ -33,13 +32,6 @@ from tests.integration.common_utils.vespa import vespa_fixture
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
user_group_1: DATestUserGroup
user_group_2: DATestUserGroup
is_ee = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create api key
@@ -86,17 +78,16 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
print("Document sets created and synced")
if is_ee:
# create user groups
user_group_1 = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
user_group_2 = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(user_performing_action=admin_user)
# create user groups
user_group_1: DATestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
user_group_2: DATestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(user_performing_action=admin_user)
# inject a finished index attempt and index attempt error (exercises foreign key errors)
with Session(get_sqlalchemy_engine()) as db_session:
@@ -156,13 +147,12 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
)
# Update local records to match the database for later comparison
user_group_1.cc_pair_ids = []
user_group_2.cc_pair_ids = [cc_pair_2.id]
doc_set_1.cc_pair_ids = []
doc_set_2.cc_pair_ids = [cc_pair_2.id]
cc_pair_1.groups = []
if is_ee:
cc_pair_2.groups = [user_group_2.id]
else:
cc_pair_2.groups = []
cc_pair_2.groups = [user_group_2.id]
CCPairManager.wait_for_deletion_completion(
cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
@@ -178,15 +168,11 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
verify_deleted=True,
)
cc_pair_2_group_name_expected = []
if is_ee:
cc_pair_2_group_name_expected = [user_group_2.name]
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_set_names=[doc_set_2.name],
group_names=cc_pair_2_group_name_expected,
group_names=[user_group_2.name],
doc_creating_user=admin_user,
verify_deleted=False,
)
@@ -207,19 +193,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
user_performing_action=admin_user,
)
if is_ee:
user_group_1.cc_pair_ids = []
user_group_2.cc_pair_ids = [cc_pair_2.id]
# validate user groups
UserGroupManager.verify(
user_group=user_group_1,
user_performing_action=admin_user,
)
UserGroupManager.verify(
user_group=user_group_2,
user_performing_action=admin_user,
)
# validate user groups
UserGroupManager.verify(
user_group=user_group_1,
user_performing_action=admin_user,
)
UserGroupManager.verify(
user_group=user_group_2,
user_performing_action=admin_user,
)
def test_connector_deletion_for_overlapping_connectors(
@@ -228,13 +210,6 @@ def test_connector_deletion_for_overlapping_connectors(
"""Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping
document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors.
"""
user_group_1: DATestUserGroup
user_group_2: DATestUserGroup
is_ee = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create api key
@@ -306,48 +281,47 @@ def test_connector_deletion_for_overlapping_connectors(
doc_creating_user=admin_user,
)
if is_ee:
# create a user group and attach it to connector 1
user_group_1 = UserGroupManager.create(
name="Test User Group 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1],
user_performing_action=admin_user,
)
cc_pair_1.groups = [user_group_1.id]
# create a user group and attach it to connector 1
user_group_1: DATestUserGroup = UserGroupManager.create(
name="Test User Group 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1],
user_performing_action=admin_user,
)
cc_pair_1.groups = [user_group_1.id]
print("User group 1 created and synced")
print("User group 1 created and synced")
# create a user group and attach it to connector 2
user_group_2 = UserGroupManager.create(
name="Test User Group 2",
cc_pair_ids=[cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2],
user_performing_action=admin_user,
)
cc_pair_2.groups = [user_group_2.id]
# create a user group and attach it to connector 2
user_group_2: DATestUserGroup = UserGroupManager.create(
name="Test User Group 2",
cc_pair_ids=[cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2],
user_performing_action=admin_user,
)
cc_pair_2.groups = [user_group_2.id]
print("User group 2 created and synced")
print("User group 2 created and synced")
# verify vespa document is in the user group
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
# verify vespa document is in the user group
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
# delete connector 1
CCPairManager.pause_cc_pair(
@@ -380,15 +354,11 @@ def test_connector_deletion_for_overlapping_connectors(
# verify the document is not in any document sets
# verify the document is only in user group 2
group_names_expected = []
if is_ee:
group_names_expected = [user_group_2.name]
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_set_names=[],
group_names=group_names_expected,
group_names=[user_group_2.name],
doc_creating_user=admin_user,
verify_deleted=False,
)

View File

@@ -1,6 +1,3 @@
import os
import pytest
import requests
from onyx.configs.constants import MessageType
@@ -15,10 +12,6 @@ from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history is enterprise only",
)
def test_all_stream_chat_message_objects_outputs(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,7 +1,5 @@
import json
import os
import pytest
import requests
from onyx.configs.constants import MessageType
@@ -18,11 +16,10 @@ from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
def test_send_message_simple_with_history(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create connectors
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user,
@@ -56,22 +53,18 @@ def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -
response_json = response.json()
# Check that the top document is the correct document
assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
# assert that the metadata is correct
for doc in cc_pair_1.documents:
found_doc = next(
(x for x in response_json["top_documents"] if x["document_id"] == doc.id),
None,
(x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None
)
assert found_doc
assert found_doc["metadata"]["document_id"] == doc.id
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -161,10 +154,6 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_send_message_simple_with_history_strict_json(
new_admin_user: DATestUser | None,
) -> None:

View File

@@ -2,8 +2,6 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connector-credential pairs.
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -17,10 +15,6 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and User Group tests are enterprise only",
)
def test_cc_pair_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -2,8 +2,6 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connectors.
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -15,10 +13,6 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_connector_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -2,8 +2,6 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating credentials.
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -14,10 +12,6 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_credential_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,5 +1,3 @@
import os
import pytest
from requests.exceptions import HTTPError
@@ -12,10 +10,6 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_doc_set_permissions_setup(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -4,8 +4,6 @@ This file tests the permissions for creating and editing personas for different
- Curators can edit personas that belong exclusively to groups they curate
- Admins can edit all personas
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -15,10 +13,6 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_persona_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,8 +1,6 @@
"""
This file tests the ability of different user types to set the role of other users.
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -12,10 +10,6 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_user_role_setting_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,10 +1,6 @@
"""
This test tests the happy path for curator permissions
"""
import os
import pytest
from onyx.db.enums import AccessType
from onyx.db.models import UserRole
from onyx.server.documents.models import DocumentSource
@@ -16,10 +12,6 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator tests are enterprise only",
)
def test_whole_curator_flow(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -97,10 +89,6 @@ def test_whole_curator_flow(reset: None) -> None:
)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator tests are enterprise only",
)
def test_global_curator_flow(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,4 +1,3 @@
import os
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -64,10 +63,6 @@ def setup_chat_session(reset: None) -> tuple[DATestUser, str]:
return admin_user, str(chat_session.id)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Chat history tests are enterprise only",
)
def test_chat_history_endpoints(
reset: None, setup_chat_session: tuple[DATestUser, str]
) -> None:
@@ -121,10 +116,6 @@ def test_chat_history_endpoints(
assert len(history_response.items) == 0
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Chat history tests are enterprise only",
)
def test_chat_history_csv_export(
reset: None, setup_chat_session: tuple[DATestUser, str]
) -> None:

View File

@@ -1,8 +1,5 @@
import os
from datetime import datetime
import pytest
from onyx.configs.constants import QAFeedbackType
from tests.integration.common_utils.managers.query_history import QueryHistoryManager
from tests.integration.common_utils.test_models import DAQueryHistoryEntry
@@ -50,10 +47,6 @@ def _verify_query_history_pagination(
assert all_expected_sessions == all_retrieved_sessions
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Query history tests are enterprise only",
)
def test_query_history_pagination(reset: None) -> None:
(
admin_user,

View File

@@ -1,42 +0,0 @@
from collections.abc import Callable
import pytest
from onyx.configs.constants import DocumentSource
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import SimpleTestDocument
DocumentBuilderType = Callable[[list[str]], list[SimpleTestDocument]]
@pytest.fixture
def document_builder(admin_user: DATestUser) -> DocumentBuilderType:
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
# create connector
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
def _document_builder(contents: list[str]) -> list[SimpleTestDocument]:
# seed documents
docs: list[SimpleTestDocument] = [
DocumentManager.seed_doc_with_content(
cc_pair=cc_pair_1,
content=content,
api_key=api_key,
)
for content in contents
]
return docs
return _document_builder

View File

@@ -5,11 +5,12 @@ import pytest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.tests.streaming_endpoints.conftest import DocumentBuilderType
def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
def test_send_message_simple_with_history(reset: None) -> None:
admin_user: DATestUser = UserManager.create(name="admin_user")
LLMProviderManager.create(user_performing_action=admin_user)
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
@@ -23,44 +24,6 @@ def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -
assert len(response.full_message) > 0
def test_send_message__basic_searches(
reset: None, admin_user: DATestUser, document_builder: DocumentBuilderType
) -> None:
MESSAGE = "run a search for 'test'"
SHORT_DOC_CONTENT = "test"
LONG_DOC_CONTENT = "blah blah blah blah" * 100
LLMProviderManager.create(user_performing_action=admin_user)
short_doc = document_builder([SHORT_DOC_CONTENT])[0]
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message=MESSAGE,
user_performing_action=admin_user,
)
assert response.top_documents is not None
assert len(response.top_documents) == 1
assert response.top_documents[0].document_id == short_doc.id
# make sure this doc is really long so that it will be split into multiple chunks
long_doc = document_builder([LONG_DOC_CONTENT])[0]
# new chat session for simplicity
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message=MESSAGE,
user_performing_action=admin_user,
)
assert response.top_documents is not None
assert len(response.top_documents) == 2
# short doc should be more relevant and thus first
assert response.top_documents[0].document_id == short_doc.id
assert response.top_documents[1].document_id == long_doc.id
@pytest.mark.skip(
reason="enable for autorun when we have a testing environment with semantically useful data"
)

View File

@@ -8,10 +8,6 @@ This tests the deletion of a user group with the following foreign key constrain
- token_rate_limit (Not Implemented)
- persona
"""
import os
import pytest
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.credential import CredentialManager
@@ -29,10 +25,6 @@ from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_user_group_deletion(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,7 +1,3 @@
import os
import pytest
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import APIKeyManager
@@ -15,10 +11,6 @@ from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_removing_connector(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -9,6 +9,8 @@ from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import DocumentSource
@@ -17,6 +19,7 @@ from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
@@ -117,7 +120,24 @@ def mock_search_results(
@pytest.fixture
def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts:
return OnyxContexts(
contexts=[
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in mock_inference_sections
]
)
@pytest.fixture
def mock_search_tool(
mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc]
) -> MagicMock:
mock_tool = MagicMock(spec=SearchTool)
mock_tool.name = "search"
mock_tool.build_tool_message_content.return_value = "search_response"
@@ -126,6 +146,7 @@ def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
json.loads(doc.model_dump_json()) for doc in mock_search_results
]
mock_tool.run.return_value = [
ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts),
ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results),
]
mock_tool.tool_definition.return_value = {

View File

@@ -19,6 +19,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
@@ -32,6 +33,7 @@ from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -139,6 +141,7 @@ def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
def test_answer_with_search_call(
answer_instance: Answer,
mock_search_results: list[LlmDoc],
mock_contexts: OnyxContexts,
mock_search_tool: MagicMock,
force_use_tool: ForceUseTool,
expected_tool_args: dict,
@@ -194,21 +197,25 @@ def test_answer_with_search_call(
tool_name="search", tool_args=expected_tool_args
)
assert output[1] == ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=mock_contexts,
)
assert output[2] == ToolResponse(
id="final_context_documents",
response=mock_search_results,
)
assert output[2] == ToolCallFinalResult(
assert output[3] == ToolCallFinalResult(
tool_name="search",
tool_args=expected_tool_args,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
)
assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
assert output[4] == expected_citation
assert output[5] == OnyxAnswerPiece(
assert output[5] == expected_citation
assert output[6] == OnyxAnswerPiece(
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
)
assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
expected_answer = (
"Based on the search results, "
@@ -261,6 +268,7 @@ def test_answer_with_search_call(
def test_answer_with_search_no_tool_calling(
answer_instance: Answer,
mock_search_results: list[LlmDoc],
mock_contexts: OnyxContexts,
mock_search_tool: MagicMock,
) -> None:
answer_instance.graph_config.tooling.tools = [mock_search_tool]
@@ -280,26 +288,30 @@ def test_answer_with_search_no_tool_calling(
output = list(answer_instance.processed_streamed_output)
# Assertions
assert len(output) == 7
assert len(output) == 8
assert output[0] == ToolCallKickoff(
tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
)
assert output[1] == ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=mock_contexts,
)
assert output[2] == ToolResponse(
id=FINAL_CONTEXT_DOCUMENTS_ID,
response=mock_search_results,
)
assert output[2] == ToolCallFinalResult(
assert output[3] == ToolCallFinalResult(
tool_name="search",
tool_args=DEFAULT_SEARCH_ARGS,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
)
assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
assert output[4] == expected_citation
assert output[5] == OnyxAnswerPiece(
assert output[5] == expected_citation
assert output[6] == OnyxAnswerPiece(
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
)
assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
expected_answer = (
"Based on the search results, "

View File

@@ -79,7 +79,7 @@ def test_skip_gen_ai_answer_generation_flag(
for res in results:
print(res)
expected_count = 3 if skip_gen_ai_answer_generation else 4
expected_count = 4 if skip_gen_ai_answer_generation else 5
assert len(results) == expected_count
if not skip_gen_ai_answer_generation:
mock_llm.stream.assert_called_once()

View File

@@ -45,7 +45,7 @@ export function ActionsTable({ tools }: { tools: ToolSnapshot[] }) {
className="mr-1 my-auto cursor-pointer"
onClick={() =>
router.push(
`/admin/actions/edit/${tool.id}?u=${Date.now()}`
`/admin/tools/edit/${tool.id}?u=${Date.now()}`
)
}
/>

View File

@@ -281,7 +281,7 @@ export default function AddConnector({
return (
<Formik
initialValues={{
...createConnectorInitialValues(connector, currentCredential),
...createConnectorInitialValues(connector),
...Object.fromEntries(
connectorConfigs[connector].advanced_values.map((field) => [
field.name,

View File

@@ -148,7 +148,8 @@ export function Explorer({
clearTimeout(timeoutId);
}
if (query && query.trim() !== "") {
let doSearch = true;
if (doSearch) {
router.replace(
`/admin/documents/explorer?query=${encodeURIComponent(query)}`
);

View File

@@ -91,7 +91,7 @@ export function AgenticToggle({
>
<div className="flex items-center space-x-2 mb-3">
<h3 className="text-sm font-semibold text-neutral-900">
Agent Search
Agent Search (BETA)
</h3>
</div>
<p className="text-xs text-neutral-600 dark:text-neutral-700 mb-2">

View File

@@ -450,6 +450,7 @@ export const AIMessage = ({
)}
</>
) : null}
{toolCall &&
!TOOLS_WITH_CUSTOM_HANDLING.includes(
toolCall.tool_name
@@ -466,10 +467,12 @@ export const AIMessage = ({
isRunning={!toolCall.tool_result || !content}
/>
)}
{toolCall &&
(!files || files.length == 0) &&
toolCall.tool_name === IMAGE_GENERATION_TOOL_NAME &&
!toolCall.tool_result && <GeneratingImageDisplay />}
{toolCall &&
toolCall.tool_name === INTERNET_SEARCH_TOOL_NAME && (
<ToolRunDisplay
@@ -484,6 +487,7 @@ export const AIMessage = ({
isRunning={!toolCall.tool_result}
/>
)}
{docs && docs.length > 0 && (
<div
className={`mobile:hidden ${
@@ -518,6 +522,7 @@ export const AIMessage = ({
</div>
</div>
)}
{content || files ? (
<>
<FileDisplay files={files || []} />
@@ -969,7 +974,7 @@ export const HumanMessage = ({
</div>
) : typeof content === "string" ? (
<>
<div className="ml-auto flex items-center mr-1 mt-2 h-fit mb-auto">
<div className="ml-auto flex items-center mr-1 h-fit my-auto">
{onEdit &&
isHovered &&
!isEditing &&

View File

@@ -1,133 +0,0 @@
import { preprocessLaTeX } from "./codeUtils";
describe("preprocessLaTeX", () => {
describe("currency formatting", () => {
it("should properly escape dollar signs in text with amounts", () => {
const input =
"Maria wants to buy a new laptop that costs $1,200. She has saved $800 so far. If she saves an additional $100 each month, how many months will it take her to have enough money to buy the laptop?";
const processed = preprocessLaTeX(input);
// Should escape all dollar signs in currency amounts
expect(processed).toContain("costs \\$1,200");
expect(processed).toContain("saved \\$800");
expect(processed).toContain("additional \\$100");
expect(processed).not.toContain("costs $1,200");
});
it("should handle dollar signs with backslashes already present", () => {
const input =
"Maria wants to buy a new laptop that costs \\$1,200. She has saved \\$800 so far.";
const processed = preprocessLaTeX(input);
// Should preserve the existing escaped dollar signs
expect(processed).toContain("\\$1,200");
expect(processed).toContain("\\$800");
});
});
describe("code block handling", () => {
it("should not process dollar signs in code blocks", () => {
const input = "```plaintext\nThe total cost is $50.\n```";
const processed = preprocessLaTeX(input);
// Dollar sign in code block should remain untouched
expect(processed).toContain("The total cost is $50.");
expect(processed).not.toContain("The total cost is \\$50.");
});
it("should not process dollar signs in inline code", () => {
const input =
'Use the `printf "$%.2f" $amount` command to format currency.';
const processed = preprocessLaTeX(input);
// Dollar signs in inline code should remain untouched
expect(processed).toContain('`printf "$%.2f" $amount`');
expect(processed).not.toContain('`printf "\\$%.2f" \\$amount`');
});
it("should handle mixed content with code blocks and currency", () => {
const input =
"The cost is $100.\n\n```javascript\nconst price = '$50';\n```\n\nThe remaining balance is $50.";
const processed = preprocessLaTeX(input);
// Dollar signs outside code blocks should be escaped
expect(processed).toContain("The cost is \\$100");
expect(processed).toContain("The remaining balance is \\$50");
// Dollar sign in code block should be preserved
expect(processed).toContain("const price = '$50';");
expect(processed).not.toContain("const price = '\\$50';");
});
});
describe("LaTeX handling", () => {
it("should preserve proper LaTeX delimiters", () => {
const input =
"The formula $x^2 + y^2 = z^2$ represents the Pythagorean theorem.";
const processed = preprocessLaTeX(input);
// LaTeX delimiters should be preserved
expect(processed).toContain("$x^2 + y^2 = z^2$");
});
it("should convert LaTeX block delimiters", () => {
const input = "Consider the equation: \\[E = mc^2\\]";
const processed = preprocessLaTeX(input);
// Block LaTeX delimiters should be converted
expect(processed).toContain("$$E = mc^2$$");
expect(processed).not.toContain("\\[E = mc^2\\]");
});
it("should convert LaTeX inline delimiters", () => {
const input =
"The speed of light \\(c\\) is approximately 299,792,458 m/s.";
const processed = preprocessLaTeX(input);
// Inline LaTeX delimiters should be converted
expect(processed).toContain("$c$");
expect(processed).not.toContain("\\(c\\)");
});
});
describe("special cases", () => {
it("should handle shell variables in text", () => {
const input =
"In bash, you can access arguments with $1, $2, and use echo $HOME to print the home directory.";
const processed = preprocessLaTeX(input);
// Verify current behavior (numeric shell variables are being escaped)
expect(processed).toContain("\\$1");
expect(processed).toContain("\\$2");
// But $HOME is not escaped (non-numeric)
expect(processed).toContain("$HOME");
});
it("should handle shell commands with dollar signs", () => {
const input = "Use awk '{print $2}' to print the second column.";
const processed = preprocessLaTeX(input);
// Dollar sign in awk command should not be escaped
expect(processed).toContain("{print $2}");
expect(processed).not.toContain("{print \\$2}");
});
it("should handle Einstein's equation with mixed LaTeX and code blocks", () => {
const input =
"Sure! The equation for Einstein's mass-energy equivalence, \\(E = mc^2\\), can be written in LaTeX as follows: ```latex\nE = mc^2\n``` When rendered, it looks like this: \\[ E = mc^2 \\]";
const processed = preprocessLaTeX(input);
// LaTeX inline delimiters should be converted
expect(processed).toContain("equivalence, $E = mc^2$,");
expect(processed).not.toContain("equivalence, \\(E = mc^2\\),");
// LaTeX block delimiters should be converted
expect(processed).toContain("it looks like this: $$ E = mc^2 $$");
expect(processed).not.toContain("it looks like this: \\[ E = mc^2 \\]");
// LaTeX within code blocks should remain untouched
expect(processed).toContain("```latex\nE = mc^2\n```");
});
});
});

View File

@@ -59,82 +59,20 @@ export function extractCodeText(
return codeText || "";
}
// We must preprocess LaTeX in the LLM output to avoid improper formatting
export const preprocessLaTeX = (content: string) => {
// First detect if content is within a code block
const codeBlockRegex = /^```[\s\S]*?```$/;
const isCodeBlock = codeBlockRegex.test(content.trim());
// If the entire content is a code block, don't process LaTeX
if (isCodeBlock) {
return content;
}
// Extract code blocks and replace with placeholders
const codeBlocks: string[] = [];
const withCodeBlocksReplaced = content.replace(/```[\s\S]*?```/g, (match) => {
const placeholder = `___CODE_BLOCK_${codeBlocks.length}___`;
codeBlocks.push(match);
return placeholder;
});
// First, protect code-like expressions where $ is used for variables
const codeProtected = withCodeBlocksReplaced.replace(
/\b(\w+(?:\s*-\w+)*\s*(?:'[^']*')?)\s*\{[^}]*?\$\d+[^}]*?\}/g,
(match) => {
// Replace $ with a temporary placeholder in code contexts
return match.replace(/\$/g, "___DOLLAR_PLACEHOLDER___");
}
);
// Also protect common shell variable patterns like $1, $2, etc.
const shellProtected = codeProtected.replace(
/\b(?:print|echo|awk|sed|grep)\s+.*?\$\d+/g,
(match) => match.replace(/\$/g, "___DOLLAR_PLACEHOLDER___")
);
// Protect inline code blocks with backticks
const inlineCodeProtected = shellProtected.replace(/`[^`]+`/g, (match) => {
return match.replace(/\$/g, "___DOLLAR_PLACEHOLDER___");
});
// Process LaTeX expressions now that code is protected
// Valid LaTeX should have matching dollar signs with non-space chars surrounding content
const processedForLatex = inlineCodeProtected.replace(
/\$([^\s$][^$]*?[^\s$])\$/g,
(_, equation) => `$${equation}$`
);
// Escape currency mentions
const currencyEscaped = processedForLatex.replace(
/\$(\d+(?:\.\d*)?)/g,
(_, p1) => `\\$${p1}`
);
// Replace block-level LaTeX delimiters \[ \] with $$ $$
const blockProcessed = currencyEscaped.replace(
// 1) Replace block-level LaTeX delimiters \[ \] with $$ $$
const blockProcessedContent = content.replace(
/\\\[([\s\S]*?)\\\]/g,
(_, equation) => `$$${equation}$$`
);
// Replace inline LaTeX delimiters \( \) with $ $
const inlineProcessed = blockProcessed.replace(
// 2) Replace inline LaTeX delimiters \( \) with $ $
const inlineProcessedContent = blockProcessedContent.replace(
/\\\(([\s\S]*?)\\\)/g,
(_, equation) => `$${equation}$`
);
// Restore original dollar signs in code contexts
const restoredDollars = inlineProcessed.replace(
/___DOLLAR_PLACEHOLDER___/g,
"$"
);
// Restore code blocks
const restoredCodeBlocks = restoredDollars.replace(
/___CODE_BLOCK_(\d+)___/g,
(_, index) => codeBlocks[parseInt(index)]
);
return restoredCodeBlocks;
return inlineProcessedContent;
};

View File

@@ -34,12 +34,11 @@
/* -------------------------------------------------------
* 2. Keep special, custom, or near-duplicate background
* ------------------------------------------------------- */
--background: #fefcfa; /* slightly off-white */
--background-50: #fffdfb; /* a little lighter than background but not quite white */
--background: #fefcfa; /* slightly off-white, keep it */
--input-background: #fefcfa;
--input-border: #f1eee8;
--text-text: #f4f2ed;
--background-dark: #141414;
--background-dark: #e9e6e0;
--new-background: #ebe7de;
--new-background-light: #d9d1c0;
--background-chatbar: #f5f3ee;
@@ -235,7 +234,6 @@
--text-text: #1d1d1d;
--background-dark: #252525;
--background-50: #252525;
/* --new-background: #fff; */
--new-background: #2c2c2c;

View File

@@ -181,7 +181,7 @@ const SignedUpUserTable = ({
: "All Roles"}
</SelectValue>
</SelectTrigger>
<SelectContent className="bg-background-50">
<SelectContent className="bg-background">
{Object.entries(USER_ROLE_LABELS)
.filter(([role]) => role !== UserRole.EXT_PERM_USER)
.map(([role, label]) => (

View File

@@ -26,13 +26,7 @@ export const buildDocumentSummaryDisplay = (
matchHighlights: string[],
blurb: string
) => {
// if there are no match highlights, or if it's really short, just use the blurb
// this is to prevent the UI from showing something like `...` for the summary
const MIN_MATCH_HIGHLIGHT_LENGTH = 5;
if (
!matchHighlights ||
matchHighlights.length <= MIN_MATCH_HIGHLIGHT_LENGTH
) {
if (!matchHighlights || matchHighlights.length === 0) {
return blurb;
}

View File

@@ -102,7 +102,7 @@ export function UserProvider({
};
// Use the custom token refresh hook
useTokenRefresh(upToDateUser, fetchUser);
// useTokenRefresh(upToDateUser, fetchUser);
const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => {
try {

View File

@@ -1292,8 +1292,7 @@ For example, specifying .*-support.* as a "channel" will cause the connector to
},
};
export function createConnectorInitialValues(
connector: ConfigurableSources,
currentCredential: Credential<any> | null = null
connector: ConfigurableSources
): Record<string, any> & AccessTypeGroupSelectorFormType {
const configuration = connectorConfigs[connector];
@@ -1308,16 +1307,7 @@ export function createConnectorInitialValues(
} else if (field.type === "list") {
acc[field.name] = field.default || [];
} else if (field.type === "checkbox") {
// Special case for include_files_shared_with_me when using service account
if (
field.name === "include_files_shared_with_me" &&
currentCredential &&
!currentCredential.credential_json?.google_tokens
) {
acc[field.name] = true;
} else {
acc[field.name] = field.default || false;
}
acc[field.name] = field.default || false;
} else if (field.default !== undefined) {
acc[field.name] = field.default;
}

View File

@@ -108,7 +108,6 @@ module.exports = {
"accent-background": "var(--accent-background)",
"accent-background-hovered": "var(--accent-background-hovered)",
"accent-background-selected": "var(--accent-background-selected)",
"background-50": "var(--background-50)",
"background-dark": "var(--off-white)",
"background-100": "var(--neutral-100-border-light)",
"background-125": "var(--neutral-125)",