Compare commits

..

18 Commits

Author SHA1 Message Date
pablonyx
b4f3bdf09f temporary fix for auth 2025-03-31 11:07:52 -07:00
joachim-danswer
e988c13e1d Additional logging for the path from Search Results to LLM Context (#4387)
* added logging

* nit

* nit
2025-03-31 00:38:43 +00:00
pablonyx
dc18d53133 Improve multi tenant anonymous user interaction (#3857)
* cleaner handling

* k

* k

* address nits

* fix typing
2025-03-31 00:33:32 +00:00
evan-danswer
a1cef389aa fallback to ignoring unicode chars when huggingface tokenizer fails (#4394) 2025-03-30 23:45:20 +00:00
pablonyx
db8d6ce538 formatting (#4316) 2025-03-30 23:43:17 +00:00
pablonyx
e8370dcb24 Update refresh conditional (#4375)
* update refresh conditional

* k
2025-03-30 17:28:35 -07:00
pablonyx
9951fe13ba Fix image input processing without LLMs (#4390)
* quick fix

* quick fix

* Revert "quick fix"

This reverts commit 906b29bd9b.

* nit
2025-03-30 19:28:49 +00:00
evan-danswer
56f8ab927b Contextual Retrieval (#4029)
* contextual rag implementation

* WIP

* indexing test fix

* workaround for chunking errors, WIP on fixing massive memory cost

* mypy and test fixes

* reformatting

* fixed rebase
2025-03-30 18:49:09 +00:00
rkuo-danswer
cb5bbd3812 Feature/mit integration tests (#4299)
* new mit integration test template

* edit

* fix problem with ACL type tags and MIT testing for test_connector_deletion

* fix test_connector_deletion_for_overlapping_connectors

* disable some enterprise only tests in MIT version

* disable a bunch of user group / curator tests in MIT version

* wire off more tests

* typo fix

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-30 02:41:08 +00:00
Yuhong Sun
742d29e504 Remove BETA 2025-03-29 15:38:46 -07:00
SubashMohan
ecc155d082 fix: ensure base_url ends with a trailing slash (#4388) 2025-03-29 14:34:30 -07:00
pablonyx
0857e4809d fix background color 2025-03-28 16:33:30 -07:00
Chris Weaver
22e00a1f5c Fix duplicate docs (#4378)
* Initial

* Fix duplicate docs

* Add tests

* Switch to list comprehension

* Fix test
2025-03-28 22:25:26 +00:00
Chris Weaver
0d0588a0c1 Remove OnyxContext (#4376)
* Remove OnyxContext

* Fix UT

* Fix tests v2
2025-03-28 12:39:51 -07:00
rkuo-danswer
aab777f844 Bugfix/acl prefix (#4377)
* fix acl prefixing

* increase timeout a tad

* block access to init'ing DocumentAccess directly, fix test to work with ee/MIT

* fix env var checks

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-03-28 05:52:35 +00:00
pablonyx
babbe7689a k (#4380) 2025-03-28 02:23:45 +00:00
evan-danswer
a123661c92 fixed shared folder issue (#4371)
* fixed shared folder issue

* fix existing tests

* default allow files shared with me for service account
2025-03-27 23:39:52 +00:00
pablonyx
c554889baf Fix actions link (#4374) 2025-03-27 16:39:35 -07:00
110 changed files with 2059 additions and 547 deletions

View File

@@ -0,0 +1,209 @@
name: Run MIT Integration Tests v2
concurrency:
group: Run-MIT-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=32cpu-linux-x64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
# tag every docker image with "test" so that we can spin up the correct set
# of images during testing
# We don't need to build the Web Docker image since it's not yet used
# in the integration tests. We have a separate action to verify that it builds
# successfully.
- name: Pull Web Docker image
run: |
docker pull onyxdotapp/onyx-web-server:latest
docker tag onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:test
# we use the runs-on cache for docker builds
# in conjunction with runs-on runners, it has better speed and unlimited caching
# https://runs-on.com/caching/s3-cache-for-github-actions/
# https://runs-on.com/caching/docker/
# https://github.com/moby/buildkit#s3-cache-experimental
# images are built and run locally for testing purposes. Not pushed.
- name: Build Backend Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-backend:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/backend/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build Model Server Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/Dockerfile.model_server
platforms: linux/amd64
tags: onyxdotapp/onyx-model-server:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/model-server/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: onyxdotapp/onyx-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Start Docker containers
run: |
cd deployment/docker_compose
AUTH_TYPE=basic \
POSTGRES_POOL_PRE_PING=true \
POSTGRES_USE_NULL_POOL=true \
REQUIRE_EMAIL_VERIFICATION=false \
DISABLE_TELEMETRY=true \
IMAGE_TAG=test \
INTEGRATION_TESTS_MODE=true \
docker compose -f docker-compose.dev.yml -p onyx-stack up -d
id: start_docker
- name: Wait for service to be ready
run: |
echo "Starting wait-for-service script..."
docker logs -f onyx-stack-api_server-1 &
start_time=$(date +%s)
timeout=300 # 5 minutes in seconds
while true; do
current_time=$(date +%s)
elapsed_time=$((current_time - start_time))
if [ $elapsed_time -ge $timeout ]; then
echo "Timeout reached. Service did not become ready in 5 minutes."
exit 1
fi
# Treat any curl failure (e.g. exit code 56 while the server is still starting) as retryable instead of failing the step
response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error")
if [ "$response" = "200" ]; then
echo "Service is ready!"
break
elif [ "$response" = "curl_error" ]; then
echo "Curl encountered an error, possibly exit code 56. Continuing to retry..."
else
echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..."
fi
sleep 5
done
echo "Finished waiting for service."
- name: Start Mock Services
run: |
cd backend/tests/integration/mock_services
docker compose -f docker-compose.mock-it-services.yml \
-p mock-it-services-stack up -d
# NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections
- name: Run Standard Integration Tests
run: |
echo "Running integration tests..."
docker run --rm --network onyx-stack_default \
--name test-runner \
-e POSTGRES_HOST=relational_db \
-e POSTGRES_USER=postgres \
-e POSTGRES_PASSWORD=password \
-e POSTGRES_DB=postgres \
-e POSTGRES_POOL_PRE_PING=true \
-e POSTGRES_USE_NULL_POOL=true \
-e VESPA_HOST=index \
-e REDIS_HOST=cache \
-e API_SERVER_HOST=api_server \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
onyxdotapp/onyx-integration:test \
/app/tests/integration/tests \
/app/tests/integration/connector_job_tests
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
# ------------------------------------------------------------
# Always gather logs BEFORE "down":
- name: Dump API server logs
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true
- name: Dump all-container logs (optional)
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true
- name: Upload logs
if: always()
uses: actions/upload-artifact@v4
with:
name: docker-all-logs
path: ${{ github.workspace }}/docker-compose.log
# ------------------------------------------------------------
- name: Stop Docker containers
if: always()
run: |
cd deployment/docker_compose
docker compose -f docker-compose.dev.yml -p onyx-stack down -v

View File

@@ -0,0 +1,50 @@
"""enable contextual retrieval
Revision ID: 8e1ac4f39a9f
Revises: 3781a5eb12cb
Create Date: 2024-12-20 13:29:09.918661
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "8e1ac4f39a9f"
down_revision = "3781a5eb12cb"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the contextual retrieval columns to the ``search_settings`` table.

    Three columns are added:
    - ``enable_contextual_rag``: non-nullable boolean with a server-side
      default of ``"false"``.
    - ``contextual_rag_llm_name``: nullable string.
    - ``contextual_rag_llm_provider``: nullable string.
    """
    table_name = "search_settings"

    # Declared in the order they are applied to the table.
    new_columns = [
        sa.Column(
            "enable_contextual_rag",
            sa.Boolean(),
            nullable=False,
            server_default="false",
        ),
        sa.Column("contextual_rag_llm_name", sa.String(), nullable=True),
        sa.Column("contextual_rag_llm_provider", sa.String(), nullable=True),
    ]

    for column in new_columns:
        op.add_column(table_name, column)
def downgrade() -> None:
    """Drop the three contextual retrieval columns from ``search_settings``."""
    # Same drop order as the original migration.
    for column_name in (
        "enable_contextual_rag",
        "contextual_rag_llm_name",
        "contextual_rag_llm_provider",
    ):
        op.drop_column("search_settings", column_name)

View File

@@ -93,12 +93,12 @@ def _get_access_for_documents(
)
# To avoid collisions of group namings between connectors, they need to be prefixed
access_map[document_id] = DocumentAccess(
user_emails=non_ee_access.user_emails,
user_groups=set(user_group_info.get(document_id, [])),
access_map[document_id] = DocumentAccess.build(
user_emails=list(non_ee_access.user_emails),
user_groups=user_group_info.get(document_id, []),
is_public=is_public_anywhere,
external_user_emails=ext_u_emails,
external_user_group_ids=ext_u_groups,
external_user_emails=list(ext_u_emails),
external_user_group_ids=list(ext_u_groups),
)
return access_map

View File

@@ -2,7 +2,6 @@ from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import AllCitations
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import QADocsResponse
from onyx.chat.models import StreamingError
from onyx.chat.process_message import ChatPacketStream
@@ -32,8 +31,6 @@ def gather_stream_for_answer_api(
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
elif isinstance(packet, AllCitations):
response.citations = packet.citations
elif isinstance(packet, OnyxContexts):
response.contexts = packet
if answer:
response.answer = answer

View File

@@ -44,7 +44,7 @@ async def _get_tenant_id_from_request(
Attempt to extract tenant_id from:
1) The API key header
2) The Redis-based token (stored in Cookie: fastapiusersauth)
3) Reset token cookie
3) The anonymous user cookie
Fallback: POSTGRES_DEFAULT_SCHEMA
"""
# Check for API key
@@ -52,41 +52,55 @@ async def _get_tenant_id_from_request(
if tenant_id is not None:
return tenant_id
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(anonymous_user_cookie)
return anonymous_user_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
try:
# Look up token data in Redis
token_data = await retrieve_auth_token_data_from_redis(request)
if not token_data:
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
if token_data:
tenant_id_from_payload = token_data.get(
"tenant_id", POSTGRES_DEFAULT_SCHEMA
)
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
tenant_id_from_payload = token_data.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else None
)
# Since token_data.get() can return None, ensure we have a string
tenant_id = (
str(tenant_id_from_payload)
if tenant_id_from_payload is not None
else POSTGRES_DEFAULT_SCHEMA
if tenant_id and not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
# Check for anonymous user cookie
anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME)
if anonymous_user_cookie:
try:
anonymous_user_data = decode_anonymous_user_jwt_token(
anonymous_user_cookie
)
tenant_id = anonymous_user_data.get(
"tenant_id", POSTGRES_DEFAULT_SCHEMA
)
if not tenant_id or not is_valid_schema_name(tenant_id):
raise HTTPException(
status_code=400, detail="Invalid tenant ID format"
)
return tenant_id
except Exception as e:
logger.error(f"Error decoding anonymous user cookie: {str(e)}")
# Continue and attempt to authenticate
logger.debug(
"Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA"
)
if not is_valid_schema_name(tenant_id):
raise HTTPException(status_code=400, detail="Invalid tenant ID format")
# Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema
# The CURRENT_TENANT_ID_CONTEXTVAR is initialized with POSTGRES_DEFAULT_SCHEMA,
# so we maintain consistency by returning it here when no valid tenant is found.
return POSTGRES_DEFAULT_SCHEMA
except Exception as e:
logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}")

View File

@@ -14,7 +14,6 @@ from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
)
from ee.onyx.server.query_and_chat.models import ChatBasicResponse
from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
@@ -56,25 +55,6 @@ logger = setup_logger()
router = APIRouter(prefix="/chat")
def _translate_doc_response_to_simple_doc(
doc_response: QADocsResponse,
) -> list[SimpleDoc]:
return [
SimpleDoc(
id=doc.document_id,
semantic_identifier=doc.semantic_identifier,
link=doc.link,
blurb=doc.blurb,
match_highlights=[
highlight for highlight in doc.match_highlights if highlight
],
source_type=doc.source_type,
metadata=doc.metadata,
)
for doc in doc_response.top_documents
]
def _get_final_context_doc_indices(
final_context_docs: list[LlmDoc] | None,
top_docs: list[SavedSearchDoc] | None,
@@ -111,9 +91,6 @@ def _convert_packet_stream_to_response(
elif isinstance(packet, QADocsResponse):
response.top_documents = packet.top_documents
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)

View File

@@ -8,7 +8,6 @@ from pydantic import model_validator
from ee.onyx.server.manage.models import StandardAnswer
from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
@@ -164,8 +163,6 @@ class ChatBasicResponse(BaseModel):
cited_documents: dict[int, str] | None = None
# FOR BACKWARDS COMPATIBILITY
# TODO: deprecate both of these
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
@@ -220,4 +217,3 @@ class OneShotQAResponse(BaseModel):
llm_selected_doc_indices: list[int] | None = None
error_msg: str | None = None
chat_message_id: int | None = None
contexts: OnyxContexts | None = None

View File

@@ -94,6 +94,7 @@ async def get_or_provision_tenant(
# Notify control plane if we have created / assigned a new tenant
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
return tenant_id
except Exception as e:

View File

@@ -18,7 +18,7 @@ def _get_access_for_document(
document_id=document_id,
)
return DocumentAccess.build(
doc_access = DocumentAccess.build(
user_emails=info[1] if info and info[1] else [],
user_groups=[],
external_user_emails=[],
@@ -26,6 +26,8 @@ def _get_access_for_document(
is_public=info[2] if info else False,
)
return doc_access
def get_access_for_document(
document_id: str,
@@ -38,12 +40,12 @@ def get_access_for_document(
def get_null_document_access() -> DocumentAccess:
return DocumentAccess(
user_emails=set(),
user_groups=set(),
return DocumentAccess.build(
user_emails=[],
user_groups=[],
is_public=False,
external_user_emails=set(),
external_user_group_ids=set(),
external_user_emails=[],
external_user_group_ids=[],
)
@@ -56,18 +58,18 @@ def _get_access_for_documents(
document_ids=document_ids,
)
doc_access = {
document_id: DocumentAccess(
user_emails=set([email for email in user_emails if email]),
document_id: DocumentAccess.build(
user_emails=[email for email in user_emails if email],
# MIT version will wipe all groups and external groups on update
user_groups=set(),
user_groups=[],
is_public=is_public,
external_user_emails=set(),
external_user_group_ids=set(),
external_user_emails=[],
external_user_group_ids=[],
)
for document_id, user_emails, is_public in document_access_info
}
# Sometimes the document has not be indexed by the indexing job yet, in those cases
# Sometimes the document has not been indexed by the indexing job yet, in those cases
# the document does not exist and so we use least permissive. Specifically the EE version
# checks the MIT version permissions and creates a superset. This ensures that this flow
# does not fail even if the Document has not yet been indexed.

View File

@@ -56,34 +56,46 @@ class DocExternalAccess:
)
@dataclass(frozen=True)
@dataclass(frozen=True, init=False)
class DocumentAccess(ExternalAccess):
# User emails for Onyx users, None indicates admin
user_emails: set[str | None]
# Names of user groups associated with this document
user_groups: set[str]
def to_acl(self) -> set[str]:
return set(
[
prefix_user_email(user_email)
for user_email in self.user_emails
if user_email
]
+ [prefix_user_group(group_name) for group_name in self.user_groups]
+ [
prefix_user_email(user_email)
for user_email in self.external_user_emails
]
+ [
# The group names are already prefixed by the source type
# This adds an additional prefix of "external_group:"
prefix_external_group(group_name)
for group_name in self.external_user_group_ids
]
+ ([PUBLIC_DOC_PAT] if self.is_public else [])
external_user_emails: set[str]
external_user_group_ids: set[str]
is_public: bool
def __init__(self) -> None:
raise TypeError(
"Use `DocumentAccess.build(...)` instead of creating an instance directly."
)
def to_acl(self) -> set[str]:
# the acl's emitted by this function are prefixed by type
# to get the native objects, access the member variables directly
acl_set: set[str] = set()
for user_email in self.user_emails:
if user_email:
acl_set.add(prefix_user_email(user_email))
for group_name in self.user_groups:
acl_set.add(prefix_user_group(group_name))
for external_user_email in self.external_user_emails:
acl_set.add(prefix_user_email(external_user_email))
for external_group_id in self.external_user_group_ids:
acl_set.add(prefix_external_group(external_group_id))
if self.is_public:
acl_set.add(PUBLIC_DOC_PAT)
return acl_set
@classmethod
def build(
cls,
@@ -93,29 +105,32 @@ class DocumentAccess(ExternalAccess):
external_user_group_ids: list[str],
is_public: bool,
) -> "DocumentAccess":
return cls(
external_user_emails={
prefix_user_email(external_email)
for external_email in external_user_emails
},
external_user_group_ids={
prefix_external_group(external_group_id)
for external_group_id in external_user_group_ids
},
user_emails={
prefix_user_email(user_email)
for user_email in user_emails
if user_email
},
user_groups=set(user_groups),
is_public=is_public,
"""Don't prefix incoming data with acl type, prefix on read from to_acl!"""
obj = object.__new__(cls)
object.__setattr__(
obj, "user_emails", {user_email for user_email in user_emails if user_email}
)
object.__setattr__(obj, "user_groups", set(user_groups))
object.__setattr__(
obj,
"external_user_emails",
{external_email for external_email in external_user_emails},
)
object.__setattr__(
obj,
"external_user_group_ids",
{external_group_id for external_group_id in external_user_group_ids},
)
object.__setattr__(obj, "is_public", is_public)
return obj
default_public_access = DocumentAccess(
external_user_emails=set(),
external_user_group_ids=set(),
user_emails=set(),
user_groups=set(),
default_public_access = DocumentAccess.build(
external_user_emails=[],
external_user_group_ids=[],
user_emails=[],
user_groups=[],
is_public=True,
)

View File

@@ -7,7 +7,6 @@ from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from onyx.chat.stream_processing.answer_response_handler import CitationResponseHandler
from onyx.chat.stream_processing.answer_response_handler import (
@@ -24,7 +23,7 @@ def process_llm_stream(
should_stream_answer: bool,
writer: StreamWriter,
final_search_results: list[LlmDoc] | None = None,
displayed_search_results: list[OnyxContext] | list[LlmDoc] | None = None,
displayed_search_results: list[LlmDoc] | None = None,
) -> AIMessageChunk:
tool_call_chunk = AIMessageChunk(content="")

View File

@@ -156,7 +156,6 @@ def generate_initial_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -183,7 +183,6 @@ def generate_validate_refined_answer(
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -57,7 +57,6 @@ def format_results(
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,

View File

@@ -13,9 +13,7 @@ from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import section_to_llm_doc
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -59,9 +57,7 @@ def basic_use_tool_response(
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(
context_from_inference_section(section)
)
initial_search_results.append(section_to_llm_doc(section))
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -360,7 +360,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
reason="Password must contain at least one special character from the following set: "
f"{PASSWORD_SPECIAL_CHARS}."
)
return
async def oauth_callback(

View File

@@ -389,6 +389,8 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
cc_pair_id=cc_pair_id,
)
credential_id_to_delete: int | None = None
connector_id_to_delete: int | None = None
if not cc_pair:
task_logger.warning(
f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}"
@@ -443,26 +445,35 @@ def monitor_connector_deletion_taskset(
db_session=db_session,
)
# Store IDs before potentially expiring cc_pair
connector_id_to_delete = cc_pair.connector_id
credential_id_to_delete = cc_pair.credential_id
# Explicitly delete document by connector credential pair records before deleting the connector
# This is needed because connector_id is a primary key in that table and cascading deletes won't work
delete_all_documents_by_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# Flush to ensure document deletion happens before connector deletion
db_session.flush()
# Expire the cc_pair to ensure SQLAlchemy doesn't try to manage its state
# related to the deleted DocumentByConnectorCredentialPair during commit
db_session.expire(cc_pair)
# finally, delete the cc-pair
delete_connector_credential_pair__no_commit(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
connector_id=connector_id_to_delete,
credential_id=credential_id_to_delete,
)
# if there are no credentials left, delete the connector
connector = fetch_connector_by_id(
db_session=db_session,
connector_id=cc_pair.connector_id,
connector_id=connector_id_to_delete,
)
if not connector or not len(connector.credentials):
task_logger.info(
@@ -495,15 +506,15 @@ def monitor_connector_deletion_taskset(
task_logger.exception(
f"Connector deletion exceptioned: "
f"cc_pair={cc_pair_id} connector={cc_pair.connector_id} credential={cc_pair.credential_id}"
f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}"
)
raise e
task_logger.info(
f"Connector deletion succeeded: "
f"cc_pair={cc_pair_id} "
f"connector={cc_pair.connector_id} "
f"credential={cc_pair.credential_id} "
f"connector={connector_id_to_delete} "
f"credential={credential_id_to_delete} "
f"docs_deleted={fence_data.num_tasks}"
)
@@ -553,7 +564,7 @@ def validate_connector_deletion_fences(
def validate_connector_deletion_fence(
tenant_id: str,
key_bytes: bytes,
queued_tasks: set[str],
queued_upsert_tasks: set[str],
r: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
@@ -640,7 +651,7 @@ def validate_connector_deletion_fence(
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
if member_str in queued_upsert_tasks:
continue
tasks_not_in_celery += 1

View File

@@ -194,17 +194,6 @@ class StreamingError(BaseModel):
stack_trace: str | None = None
class OnyxContext(BaseModel):
content: str
document_id: str
semantic_identifier: str
blurb: str
class OnyxContexts(BaseModel):
contexts: list[OnyxContext]
class OnyxAnswer(BaseModel):
answer: str | None
@@ -270,7 +259,6 @@ class PersonaOverrideConfig(BaseModel):
AnswerQuestionPossibleReturn = (
OnyxAnswerPiece
| CitationInfo
| OnyxContexts
| FileChatDisplay
| CustomToolResponse
| StreamingError

View File

@@ -29,7 +29,6 @@ from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import MessageResponseIDInfo
from onyx.chat.models import MessageSpecificCitations
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
@@ -131,7 +130,6 @@ from onyx.tools.tool_implementations.internet_search.internet_search_tool import
from onyx.tools.tool_implementations.search.search_tool import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
@@ -300,7 +298,6 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| OnyxContexts
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -919,8 +916,6 @@ def stream_chat_message_objects(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
yield cast(OnyxContexts, packet.response)
elif isinstance(packet, StreamStopInfo):
if packet.stop_reason == StreamStopReason.FINISHED:

View File

@@ -153,6 +153,8 @@ def _apply_pruning(
# remove docs that are explicitly marked as not for QA
sections = _remove_sections_to_ignore(sections=sections)
section_idx_token_count: dict[int, int] = {}
final_section_ind = None
total_tokens = 0
for ind, section in enumerate(sections):
@@ -202,10 +204,20 @@ def _apply_pruning(
section_token_count = DOC_EMBEDDING_CONTEXT_SIZE
total_tokens += section_token_count
section_idx_token_count[ind] = section_token_count
if total_tokens > token_limit:
final_section_ind = ind
break
try:
logger.debug(f"Number of documents after pruning: {ind + 1}")
logger.debug("Number of tokens per document (pruned):")
for x, y in section_idx_token_count.items():
logger.debug(f"{x + 1}: {y}")
except Exception as e:
logger.error(f"Error logging prune statistics: {e}")
if final_section_ind is not None:
if is_manually_selected_docs or use_sections:
if final_section_ind != len(sections) - 1:
@@ -301,6 +313,10 @@ def prune_sections(
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
@@ -358,6 +374,26 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
reverse=True,
)
try:
num_original_sections = len(sections)
num_original_document_ids = len(
set([section.center_chunk.document_id for section in sections])
)
num_merged_sections = len(new_sections)
num_merged_document_ids = len(
set([section.center_chunk.document_id for section in new_sections])
)
logger.debug(
f"Merged {num_original_sections} sections from {num_original_document_ids} documents "
f"into {num_merged_sections} new sections in {num_merged_document_ids} documents"
)
logger.debug("Number of chunks per document (new ranking):")
for x, y in enumerate(new_sections):
logger.debug(f"{x + 1}: {len(y.chunks)}")
except Exception as e:
logger.error(f"Error logging merge statistics: {e}")
return new_sections

View File

@@ -3,7 +3,6 @@ from collections.abc import Sequence
from pydantic import BaseModel
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceChunk
@@ -12,7 +11,7 @@ class DocumentIdOrderMapping(BaseModel):
def map_document_id_order(
chunks: Sequence[InferenceChunk | LlmDoc | OnyxContext], one_indexed: bool = True
chunks: Sequence[InferenceChunk | LlmDoc], one_indexed: bool = True
) -> DocumentIdOrderMapping:
order_mapping = {}
current = 1 if one_indexed else 0

View File

@@ -495,6 +495,11 @@ NUM_SECONDARY_INDEXING_WORKERS = int(
ENABLE_MULTIPASS_INDEXING = (
os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
)
# Enable contextual retrieval
ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true"
DEFAULT_CONTEXTUAL_RAG_LLM_NAME = "gpt-4o-mini"
DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER = "DevEnvPresetOpenAI"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
@@ -536,6 +541,17 @@ MAX_FILE_SIZE_BYTES = int(
os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
) # 2GB in bytes
# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag
USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true"
# Average summary embeddings for contextual rag (not yet implemented)
AVERAGE_SUMMARY_EMBEDDINGS = (
os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true"
)
MAX_TOKENS_FOR_FULL_INCLUSION = 4096
#####
# Miscellaneous
#####

View File

@@ -28,7 +28,9 @@ from onyx.connectors.google_drive.doc_conversion import (
)
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import get_all_files_in_my_drive
from onyx.connectors.google_drive.file_retrieval import (
get_all_files_in_my_drive_and_shared,
)
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
from onyx.connectors.google_drive.models import DriveRetrievalStage
@@ -86,13 +88,18 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
allow_images: bool,
size_threshold: int,
retriever_email: str,
file: dict[str, Any],
) -> Document | ConnectorFailure | None:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
# We used to always get the user email from the file owners when available,
# but this was causing issues with shared folders where the owner was not
# included in the service account. Now we use the email of the account that
# successfully listed the file. Leaving the old lookup below in case we end up
# wanting to retry with file owners and/or the admin email at some point.
# user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_email = retriever_email
# Only construct these services when needed
user_drive_service = lazy_eval(
lambda: get_drive_service(creds, user_email=user_email)
@@ -450,10 +457,11 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.info(f"Getting all files in my drive as '{user_email}'")
yield from add_retrieval_info(
get_all_files_in_my_drive(
get_all_files_in_my_drive_and_shared(
service=drive_service,
update_traversed_ids_func=self._update_traversed_parent_ids,
is_slim=is_slim,
include_shared_with_me=self.include_files_shared_with_me,
start=curr_stage.completed_until if resuming else start,
end=end,
),
@@ -916,20 +924,28 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
self.allow_images,
self.size_threshold,
)
# Fetch files in batches
batches_complete = 0
files_batch: list[GoogleDriveFileType] = []
files_batch: list[RetrievedDriveFile] = []
def _yield_batch(
files_batch: list[GoogleDriveFileType],
files_batch: list[RetrievedDriveFile],
) -> Iterator[Document | ConnectorFailure]:
nonlocal batches_complete
# Process the batch using run_functions_tuples_in_parallel
func_with_args = [(convert_func, (file,)) for file in files_batch]
func_with_args = [
(
convert_func,
(
file.user_email,
file.drive_file,
),
)
for file in files_batch
]
results = cast(
list[Document | ConnectorFailure | None],
run_functions_tuples_in_parallel(func_with_args, max_workers=8),
@@ -967,7 +983,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
)
continue
files_batch.append(retrieved_file.drive_file)
files_batch.append(retrieved_file)
if len(files_batch) < self.batch_size:
continue

View File

@@ -30,6 +30,7 @@ from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.llm.interfaces import LLM
from onyx.utils.lazy import lazy_eval
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -76,6 +77,26 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
return is_valid_image_type(mime_type)
def download_request(service: GoogleDriveService, file_id: str) -> bytes:
    """
    Fetch the raw contents of a Drive file via the files().get_media endpoint.

    The media download is driven chunk by chunk until the downloader reports
    completion. Returns the downloaded bytes; if nothing was received, a
    warning is logged and an empty bytes object is returned instead.
    """
    media_request = service.files().get_media(fileId=file_id)
    buffer = io.BytesIO()
    downloader = MediaIoBaseDownload(buffer, media_request)

    # next_chunk() returns (status, done); keep pulling until done is truthy
    finished = False
    while not finished:
        _, finished = downloader.next_chunk()

    payload = buffer.getvalue()
    if payload:
        return payload

    logger.warning(f"Failed to download {file_id}")
    return bytes()
def _download_and_extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
@@ -87,35 +108,17 @@ def _download_and_extract_sections_basic(
mime_type = file["mimeType"]
link = file.get("webViewLink", "")
try:
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# skip images if not explicitly enabled
if not allow_images and is_gdrive_image_mime_type(mime_type):
return []
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
@@ -124,88 +127,97 @@ def _download_and_extract_sections_basic(
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
response_call = lazy_eval(lambda: download_request(service, file_id))
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
# Process based on mime type
if mime_type == "text/plain":
text = response_call().decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
elif (
mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response_call()))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response_call(),
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call()))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []
else:
# Video, audio, and zip files are skipped outright: return no sections for them
if mime_type in [
"application/vnd.google-apps.video",
"application/vnd.google-apps.audio",
"application/zip",
]:
return []
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response_call()), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
def convert_drive_item_to_document(

View File

@@ -214,10 +214,11 @@ def get_files_in_shared_drive(
yield file
def get_all_files_in_my_drive(
def get_all_files_in_my_drive_and_shared(
service: GoogleDriveService,
update_traversed_ids_func: Callable,
is_slim: bool,
include_shared_with_me: bool,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
@@ -229,7 +230,8 @@ def get_all_files_in_my_drive(
# Get all folders being queried and add them to the traversed set
folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
folder_query += " and trashed = false"
folder_query += " and 'me' in owners"
if not include_shared_with_me:
folder_query += " and 'me' in owners"
found_folders = False
for file in execute_paginated_retrieval(
retrieval_function=service.files().list,
@@ -246,7 +248,8 @@ def get_all_files_in_my_drive(
# Then get the files
file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
file_query += " and trashed = false"
file_query += " and 'me' in owners"
if not include_shared_with_me:
file_query += " and 'me' in owners"
file_query += _generate_time_range_filter(start, end)
yield from execute_paginated_retrieval(
retrieval_function=service.files().list,

View File

@@ -75,7 +75,7 @@ class HighspotClient:
self.key = key
self.secret = secret
self.base_url = base_url
self.base_url = base_url.rstrip("/") + "/"
self.timeout = timeout
# Set up session with retry logic

View File

@@ -163,6 +163,9 @@ class DocumentBase(BaseModel):
attributes.append(k + INDEX_SEPARATOR + v)
return attributes
def get_text_content(self) -> str:
    """Return the text of all sections that have non-empty text, joined by single spaces."""
    texts = []
    for section in self.sections:
        # sections whose text is None or "" contribute nothing
        if section.text:
            texts.append(section.text)
    return " ".join(texts)
class Document(DocumentBase):
"""Used for Onyx ingestion api, the ID is required"""

View File

@@ -60,7 +60,7 @@ class SearchSettingsCreationRequest(InferenceSettings, IndexingSetting):
inference_settings = InferenceSettings.from_db_model(search_settings)
indexing_setting = IndexingSetting.from_db_model(search_settings)
return cls(**inference_settings.dict(), **indexing_setting.dict())
return cls(**inference_settings.model_dump(), **indexing_setting.model_dump())
class SavedSearchSettings(InferenceSettings, IndexingSetting):
@@ -80,6 +80,9 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
reduced_dimension=search_settings.reduced_dimension,
# Whether switching to this model requires re-indexing
background_reindex_enabled=search_settings.background_reindex_enabled,
enable_contextual_rag=search_settings.enable_contextual_rag,
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
# Reranking Details
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
@@ -218,6 +221,8 @@ class InferenceChunk(BaseChunk):
# to specify that a set of words should be highlighted. For example:
# ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
match_highlights: list[str]
doc_summary: str
chunk_context: str
# when the doc was last updated
updated_at: datetime | None

View File

@@ -339,6 +339,12 @@ class SearchPipeline:
self._retrieved_sections = self._get_sections()
return self._retrieved_sections
@property
def merged_retrieved_sections(self) -> list[InferenceSection]:
    """Retrieved sections after being passed through _merge_sections.

    Intended for display in the UI, so that several sections belonging to the
    same document are not shown as separate "documents".
    """
    retrieved = self.retrieved_sections
    return _merge_sections(sections=retrieved)
@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -415,6 +421,10 @@ class SearchPipeline:
raise ValueError(
"Basic search evaluation operation called while DISABLE_LLM_DOC_RELEVANCE is enabled."
)
# NOTE: final_context_sections must be accessed before accessing self._postprocessing_generator
# since the property sets the generator. DO NOT REMOVE.
_ = self.final_context_sections
self._section_relevance = next(
cast(
Iterator[list[SectionRelevancePiece]],

View File

@@ -196,9 +196,21 @@ def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk
RETURN_SEPARATOR
)
def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
    """Strip the contextual-RAG additions from ``chunk.content``.

    Removes ``chunk.doc_summary`` when the content starts with it and
    ``chunk.chunk_context`` when the content ends with it, trimming the
    whitespace that was adjacent to the removed text. ``chunk.content`` is
    updated in place and the cleaned content is also returned.

    Empty summary/context values are skipped explicitly: ``str.startswith("")``
    and ``str.endswith("")`` are always true, so without the guard every chunk
    would have its leading/trailing whitespace stripped even though nothing
    was actually removed.
    """
    content = chunk.content

    # remove document summary
    if chunk.doc_summary and content.startswith(chunk.doc_summary):
        content = content[len(chunk.doc_summary) :].lstrip()

    # remove chunk context
    if chunk.chunk_context and content.endswith(chunk.chunk_context):
        content = content[: len(content) - len(chunk.chunk_context)].rstrip()

    chunk.content = content
    return content
for chunk in chunks:
chunk.content = _remove_title(chunk)
chunk.content = _remove_metadata_suffix(chunk)
chunk.content = _remove_contextual_rag(chunk)
return [chunk.to_inference_chunk() for chunk in chunks]

View File

@@ -791,6 +791,15 @@ class SearchSettings(Base):
# Mini and Large Chunks (large chunk also checks for model max context)
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
# Contextual RAG
enable_contextual_rag: Mapped[bool] = mapped_column(Boolean, default=False)
# Contextual RAG LLM
contextual_rag_llm_name: Mapped[str | None] = mapped_column(String, nullable=True)
contextual_rag_llm_provider: Mapped[str | None] = mapped_column(
String, nullable=True
)
multilingual_expansion: Mapped[list[str]] = mapped_column(
postgresql.ARRAY(String), default=[]
)

View File

@@ -62,6 +62,9 @@ def create_search_settings(
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
enable_contextual_rag=search_settings.enable_contextual_rag,
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
multilingual_expansion=search_settings.multilingual_expansion,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
rerank_model_name=search_settings.rerank_model_name,
@@ -319,6 +322,7 @@ def get_old_default_embedding_model() -> IndexingSetting:
passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""),
index_name="danswer_chunk",
multipass_indexing=False,
enable_contextual_rag=False,
api_url=None,
)
@@ -333,5 +337,6 @@ def get_new_default_embedding_model() -> IndexingSetting:
passage_prefix=ASYM_PASSAGE_PREFIX,
index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}",
multipass_indexing=False,
enable_contextual_rag=False,
api_url=None,
)

View File

@@ -98,6 +98,12 @@ schema DANSWER_CHUNK_NAME {
field metadata type string {
indexing: summary | attribute
}
field chunk_context type string {
indexing: summary | attribute
}
field doc_summary type string {
indexing: summary | attribute
}
field metadata_suffix type string {
indexing: summary | attribute
}

View File

@@ -24,9 +24,11 @@ from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOC_SUMMARY
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
@@ -126,7 +128,8 @@ def _vespa_hit_to_inference_chunk(
return InferenceChunkUncleaned(
chunk_id=fields[CHUNK_ID],
blurb=fields.get(BLURB, ""), # Unused
content=fields[CONTENT], # Includes extra title prefix and metadata suffix
content=fields[CONTENT], # Includes extra title prefix and metadata suffix;
# also sometimes context for contextual rag
source_links=source_links_dict or {0: ""},
section_continuation=fields[SECTION_CONTINUATION],
document_id=fields[DOCUMENT_ID],
@@ -143,6 +146,8 @@ def _vespa_hit_to_inference_chunk(
large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []),
metadata=metadata,
metadata_suffix=fields.get(METADATA_SUFFIX),
doc_summary=fields.get(DOC_SUMMARY, ""),
chunk_context=fields.get(CHUNK_CONTEXT, ""),
match_highlights=match_highlights,
updated_at=updated_at,
)
@@ -345,6 +350,19 @@ def query_vespa(
filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None]
inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits]
try:
num_retrieved_inference_chunks = len(inference_chunks)
num_retrieved_document_ids = len(
set([chunk.document_id for chunk in inference_chunks])
)
logger.debug(
f"Retrieved {num_retrieved_inference_chunks} inference chunks for {num_retrieved_document_ids} documents"
)
except Exception as e:
# Debug logging only, should not fail the retrieval
logger.error(f"Error logging retrieval statistics: {e}")
# Good Debugging Spot
return inference_chunks

View File

@@ -187,7 +187,7 @@ class VespaIndex(DocumentIndex):
) -> None:
if MULTI_TENANT:
logger.info(
"Skipping Vespa index seup for multitenant (would wipe all indices)"
"Skipping Vespa index setup for multitenant (would wipe all indices)"
)
return None
@@ -821,30 +821,26 @@ class VespaIndex(DocumentIndex):
num_to_retrieve: int = NUM_RETURNED_HITS,
offset: int = 0,
) -> list[InferenceChunkUncleaned]:
vespa_where_clauses = build_vespa_filters(
filters, include_hidden=True, remove_trailing_and=True
vespa_where_clauses = build_vespa_filters(filters, include_hidden=True)
yql = (
YQL_BASE.format(index_name=self.index_name)
+ vespa_where_clauses
+ '({grammar: "weakAnd"}userInput(@query) '
# `({defaultIndex: "content_summary"}userInput(@query))` section is
# needed for highlighting while the N-gram highlighting is broken /
# not working as desired
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
)
yql = YQL_BASE.format(index_name=self.index_name) + vespa_where_clauses
params: dict[str, str | int] = {
"yql": yql,
"query": query,
"hits": num_to_retrieve,
"offset": 0,
"ranking.profile": "admin_search",
"timeout": VESPA_TIMEOUT,
}
if len(query.strip()) > 0:
yql += (
' and ({grammar: "weakAnd"}userInput(@query) '
# `({defaultIndex: "content_summary"}userInput(@query))` section is
# needed for highlighting while the N-gram highlighting is broken /
# not working as desired
+ f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
)
params["yql"] = yql
params["query"] = query
return query_vespa(params)
# Retrieves chunk information for a document:

View File

@@ -25,9 +25,11 @@ from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOC_SUMMARY
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
@@ -174,7 +176,7 @@ def _index_vespa_chunk(
# For the BM25 index, the keyword suffix is used, the vector is already generated with the more
# natural language representation of the metadata section
CONTENT: remove_invalid_unicode_chars(
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_keyword}"
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}"
),
# This duplication of `content` is needed for keyword highlighting
# Note that it's not exactly the same as the actual content
@@ -189,6 +191,8 @@ def _index_vespa_chunk(
# Save as a list for efficient extraction as an Attribute
METADATA_LIST: metadata_list,
METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
CHUNK_CONTEXT: chunk.chunk_context,
DOC_SUMMARY: chunk.doc_summary,
EMBEDDINGS: embeddings_name_vector_map,
TITLE_EMBEDDING: chunk.title_embedding,
DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),

View File

@@ -71,6 +71,8 @@ LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids"
METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
DOC_SUMMARY = "doc_summary"
CHUNK_CONTEXT = "chunk_context"
BOOST = "boost"
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
@@ -106,6 +108,8 @@ YQL_BASE = (
f"{LARGE_CHUNK_REFERENCE_IDS}, "
f"{METADATA}, "
f"{METADATA_SUFFIX}, "
f"{DOC_SUMMARY}, "
f"{CHUNK_CONTEXT}, "
f"{CONTENT_SUMMARY} "
f"from {{index_name}} where "
)

View File

@@ -1,7 +1,10 @@
from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import LARGE_CHUNK_RATIO
from onyx.configs.app_configs import MINI_CHUNK_SIZE
from onyx.configs.app_configs import SKIP_METADATA_IN_CHUNK
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.constants import SECTION_SEPARATOR
@@ -13,6 +16,7 @@ from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import clean_text
@@ -82,6 +86,9 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks],
mini_chunk_texts=None,
large_chunk_id=large_chunk_id,
chunk_context="",
doc_summary="",
contextual_rag_reserved_tokens=0,
)
offset = 0
@@ -120,6 +127,7 @@ class Chunker:
tokenizer: BaseTokenizer,
enable_multipass: bool = False,
enable_large_chunks: bool = False,
enable_contextual_rag: bool = False,
blurb_size: int = BLURB_SIZE,
include_metadata: bool = not SKIP_METADATA_IN_CHUNK,
chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE,
@@ -133,9 +141,20 @@ class Chunker:
self.chunk_token_limit = chunk_token_limit
self.enable_multipass = enable_multipass
self.enable_large_chunks = enable_large_chunks
self.enable_contextual_rag = enable_contextual_rag
if enable_contextual_rag:
assert (
USE_CHUNK_SUMMARY or USE_DOCUMENT_SUMMARY
), "Contextual RAG requires at least one of chunk summary and document summary enabled"
self.default_contextual_rag_reserved_tokens = MAX_CONTEXT_TOKENS * (
int(USE_CHUNK_SUMMARY) + int(USE_DOCUMENT_SUMMARY)
)
self.tokenizer = tokenizer
self.callback = callback
self.max_context = 0
self.prompt_tokens = 0
self.blurb_splitter = SentenceSplitter(
tokenizer=tokenizer.tokenize,
chunk_size=blurb_size,
@@ -221,30 +240,12 @@ class Chunker:
metadata_suffix_keyword=metadata_suffix_keyword,
mini_chunk_texts=self._get_mini_chunk_texts(text),
large_chunk_id=None,
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document
)
chunks_list.append(new_chunk)
def _chunk_document(
    self,
    document: IndexingDocument,
    title_prefix: str,
    metadata_suffix_semantic: str,
    metadata_suffix_keyword: str,
    content_token_limit: int,
) -> list[DocAwareChunk]:
    """
    Legacy method for backward compatibility.
    Delegates to _chunk_document_with_sections, passing
    document.processed_sections as the sections to chunk; every other
    argument is forwarded unchanged and the result is returned as-is.
    """
    return self._chunk_document_with_sections(
        document,
        document.processed_sections,
        title_prefix,
        metadata_suffix_semantic,
        metadata_suffix_keyword,
        content_token_limit,
    )
def _chunk_document_with_sections(
self,
document: IndexingDocument,
@@ -264,7 +265,7 @@ class Chunker:
for section_idx, section in enumerate(sections):
# Get section text and other attributes
section_text = clean_text(section.text or "")
section_text = clean_text(str(section.text or ""))
section_link_text = section.link or ""
image_url = section.image_file_name
@@ -309,7 +310,7 @@ class Chunker:
continue
# CASE 2: Normal text section
section_token_count = len(self.tokenizer.tokenize(section_text))
section_token_count = len(self.tokenizer.encode(section_text))
# If the section is large on its own, split it separately
if section_token_count > content_token_limit:
@@ -332,8 +333,7 @@ class Chunker:
# If even the split_text is bigger than strict limit, further split
if (
STRICT_CHUNK_TOKEN_LIMIT
and len(self.tokenizer.tokenize(split_text))
> content_token_limit
and len(self.tokenizer.encode(split_text)) > content_token_limit
):
smaller_chunks = self._split_oversized_chunk(
split_text, content_token_limit
@@ -363,10 +363,10 @@ class Chunker:
continue
# If we can still fit this section into the current chunk, do so
current_token_count = len(self.tokenizer.tokenize(chunk_text))
current_token_count = len(self.tokenizer.encode(chunk_text))
current_offset = len(shared_precompare_cleanup(chunk_text))
next_section_tokens = (
len(self.tokenizer.tokenize(SECTION_SEPARATOR)) + section_token_count
len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count
)
if next_section_tokens + current_token_count <= content_token_limit:
@@ -414,7 +414,7 @@ class Chunker:
# Title prep
title = self._extract_blurb(document.get_title_for_document_index() or "")
title_prefix = title + RETURN_SEPARATOR if title else ""
title_tokens = len(self.tokenizer.tokenize(title_prefix))
title_tokens = len(self.tokenizer.encode(title_prefix))
# Metadata prep
metadata_suffix_semantic = ""
@@ -427,15 +427,50 @@ class Chunker:
) = _get_metadata_suffix_for_document_index(
document.metadata, include_separator=True
)
metadata_tokens = len(self.tokenizer.tokenize(metadata_suffix_semantic))
metadata_tokens = len(self.tokenizer.encode(metadata_suffix_semantic))
# If metadata is too large, skip it in the semantic content
if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE:
metadata_suffix_semantic = ""
metadata_tokens = 0
single_chunk_fits = True
doc_token_count = 0
if self.enable_contextual_rag:
doc_content = document.get_text_content()
tokenized_doc = self.tokenizer.tokenize(doc_content)
doc_token_count = len(tokenized_doc)
# check if doc + title + metadata fits in a single chunk. If so, no need for contextual RAG
single_chunk_fits = (
doc_token_count + title_tokens + metadata_tokens
<= self.chunk_token_limit
)
# expand the size of the context used for contextual rag based on whether chunk context and doc summary are used
context_size = 0
if (
self.enable_contextual_rag
and not single_chunk_fits
and not AVERAGE_SUMMARY_EMBEDDINGS
):
context_size += self.default_contextual_rag_reserved_tokens
# Adjust content token limit to accommodate title + metadata
content_token_limit = self.chunk_token_limit - title_tokens - metadata_tokens
content_token_limit = (
self.chunk_token_limit - title_tokens - metadata_tokens - context_size
)
# first check: if there is not enough actual chunk content when including contextual rag,
# then don't do contextual rag
if content_token_limit <= CHUNK_MIN_CONTENT:
context_size = 0 # Don't do contextual RAG
# revert to previous content token limit
content_token_limit = (
self.chunk_token_limit - title_tokens - metadata_tokens
)
# If there is not enough context remaining then just index the chunk with no prefix/suffix
if content_token_limit <= CHUNK_MIN_CONTENT:
# Not enough space left, so revert to full chunk without the prefix
content_token_limit = self.chunk_token_limit
@@ -459,6 +494,9 @@ class Chunker:
large_chunks = generate_large_chunks(normal_chunks)
normal_chunks.extend(large_chunks)
for chunk in normal_chunks:
chunk.contextual_rag_reserved_tokens = context_size
return normal_chunks
def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]:

View File

@@ -121,7 +121,7 @@ class DefaultIndexingEmbedder(IndexingEmbedder):
if chunk.large_chunk_reference_ids:
large_chunks_present = True
chunk_text = (
f"{chunk.title_prefix}{chunk.content}{chunk.metadata_suffix_semantic}"
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}"
) or chunk.source_document.get_title_for_document_index()
if not chunk_text:

View File

@@ -1,3 +1,4 @@
from collections import defaultdict
from collections.abc import Callable
from functools import partial
from typing import Protocol
@@ -8,7 +9,13 @@ from sqlalchemy.orm import Session
from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
@@ -36,9 +43,10 @@ from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Document as DBDocument
from onyx.db.models import IndexModelStatus
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.db.pg_file_store import read_lobj
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_active_search_settings
from onyx.db.tag import create_or_add_document_tag
from onyx.db.tag import create_or_add_document_tag_list
from onyx.document_index.document_index_utils import (
@@ -57,11 +65,24 @@ from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
from onyx.llm.interfaces import LLM
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.llm.utils import message_to_string
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
@@ -249,6 +270,8 @@ def index_doc_batch_with_handler(
db_session: Session,
tenant_id: str,
ignore_time_skip: bool = False,
enable_contextual_rag: bool = False,
llm: LLM | None = None,
) -> IndexingPipelineResult:
try:
index_pipeline_result = index_doc_batch(
@@ -261,6 +284,8 @@ def index_doc_batch_with_handler(
db_session=db_session,
ignore_time_skip=ignore_time_skip,
tenant_id=tenant_id,
enable_contextual_rag=enable_contextual_rag,
llm=llm,
)
except Exception as e:
# don't log the batch directly, it's too much text
@@ -439,7 +464,7 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
**document.dict(),
processed_sections=[
Section(
text=section.text if isinstance(section, TextSection) else None,
text=section.text if isinstance(section, TextSection) else "",
link=section.link,
image_file_name=section.image_file_name
if isinstance(section, ImageSection)
@@ -459,11 +484,11 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
for section in document.sections:
# For ImageSection, process and create base Section with both text and image_file_name
if isinstance(section, ImageSection):
# Default section with image path preserved
# Default section with image path preserved - ensure text is always a string
processed_section = Section(
link=section.link,
image_file_name=section.image_file_name,
text=None, # Will be populated if summarization succeeds
text="", # Initialize with empty string
)
# Try to get image summary
@@ -506,13 +531,21 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
# For TextSection, create a base Section with text and link
elif isinstance(section, TextSection):
processed_section = Section(
text=section.text, link=section.link, image_file_name=None
text=section.text or "", # Ensure text is always a string, not None
link=section.link,
image_file_name=None,
)
processed_sections.append(processed_section)
# If it's already a base Section (unlikely), just append it
# If it's already a base Section (unlikely), just append it with text validation
else:
processed_sections.append(section)
# Ensure text is always a string
processed_section = Section(
text=section.text if section.text is not None else "",
link=section.link,
image_file_name=section.image_file_name,
)
processed_sections.append(processed_section)
# Create IndexingDocument with original sections and processed_sections
indexed_document = IndexingDocument(
@@ -523,6 +556,145 @@ def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
return indexed_documents
def add_document_summaries(
    chunks_by_doc: list[DocAwareChunk],
    llm: LLM,
    tokenizer: BaseTokenizer,
    trunc_doc_tokens: int,
) -> list[int] | None:
    """
    Generates a single LLM summary for a document and stores it on every
    chunk of that document (in `chunk.doc_summary`).

    Args:
        chunks_by_doc: chunks that all originate from the same source document.
        llm: the LLM invoked to produce the summary.
        tokenizer: tokenizer used to encode the document text and to trim it.
        trunc_doc_tokens: token budget handed to `tokenizer_trim_middle` for
            the document text that is placed in the summary prompt.

    Returns:
        The token ids of the full (untrimmed) document text, so that callers
        can reuse them instead of re-encoding the document, or None when no
        summary was generated.
    """
    # Nothing to summarize; avoids an IndexError on the lookup below
    if not chunks_by_doc:
        return None

    # This value is the same for each chunk in the document; 0 indicates
    # there is not enough space for contextual RAG (the chunk content
    # and possibly metadata took up too much space)
    if chunks_by_doc[0].contextual_rag_reserved_tokens == 0:
        return None

    doc_tokens = tokenizer.encode(chunks_by_doc[0].source_document.get_text_content())
    doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_tokens, tokenizer)
    summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content)
    doc_summary = message_to_string(
        llm.invoke(summary_prompt, max_tokens=MAX_CONTEXT_TOKENS)
    )

    # Every chunk of the document shares the same summary
    for chunk in chunks_by_doc:
        chunk.doc_summary = doc_summary

    return doc_tokens
def add_chunk_summaries(
    chunks_by_doc: list[DocAwareChunk],
    llm: LLM,
    tokenizer: BaseTokenizer,
    trunc_doc_chunk_tokens: int,
    doc_tokens: list[int] | None,
) -> None:
    """
    Populates `chunk_context` on each chunk of a single document.

    For every chunk the LLM is prompted with document-level information plus
    the chunk content and asked to situate the chunk within the document.
    The document-level information is the (middle-trimmed) document text when
    the document has at most MAX_TOKENS_FOR_FULL_INCLUSION tokens, otherwise
    the document summary. Failures for an individual chunk are logged and
    leave that chunk with an empty context.

    Args:
        chunks_by_doc: chunks that all originate from the same source document.
        llm: the LLM invoked for the context (and, if needed, the summary).
        tokenizer: tokenizer used to encode and trim the document text.
        trunc_doc_chunk_tokens: token budget for the document text in prompts.
        doc_tokens: previously computed token ids of the document, if any.
    """
    first_chunk = chunks_by_doc[0]

    # contextual_rag_reserved_tokens is identical across a document's chunks
    if first_chunk.contextual_rag_reserved_tokens == 0:
        return

    # Reuse the token ids from the document summary step when they were given
    if not doc_tokens:
        doc_tokens = tokenizer.encode(first_chunk.source_document.get_text_content())
    doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_chunk_tokens, tokenizer)

    # Short documents are included directly; longer ones use their summary
    if len(doc_tokens) <= MAX_TOKENS_FOR_FULL_INCLUSION:
        doc_info = doc_content
    else:
        doc_info = first_chunk.doc_summary

    if not doc_info:
        # This happens if the document is too long AND document summaries are
        # turned off, so the summary is computed here with the LLM instead
        summary_response = llm.invoke(
            DOCUMENT_SUMMARY_PROMPT.format(document=doc_content),
            max_tokens=MAX_CONTEXT_TOKENS,
        )
        doc_info = message_to_string(summary_response)

    document_prompt = CONTEXTUAL_RAG_PROMPT1.format(document=doc_info)

    def assign_context(chunk: DocAwareChunk) -> None:
        chunk_prompt = CONTEXTUAL_RAG_PROMPT2.format(chunk=chunk.content)
        context = ""
        try:
            response = llm.invoke(
                document_prompt + chunk_prompt,
                max_tokens=MAX_CONTEXT_TOKENS,
            )
            context = message_to_string(response)
        except LLMRateLimitError as e:
            # Erroring during chunking is undesirable, so log and carry on
            # TODO: for v2, add robust retry logic
            logger.exception(f"Rate limit adding chunk summary: {e}", exc_info=e)
        except Exception as e:
            logger.exception(f"Error adding chunk summary: {e}", exc_info=e)
        chunk.chunk_context = context

    tasks = [(assign_context, (chunk,)) for chunk in chunks_by_doc]
    run_functions_tuples_in_parallel(tasks)
def add_contextual_summaries(
    chunks: list[DocAwareChunk],
    llm: LLM,
    tokenizer: BaseTokenizer,
    chunk_token_limit: int,
) -> list[DocAwareChunk]:
    """
    Enriches chunks with a document summary and/or a per-chunk context,
    depending on USE_DOCUMENT_SUMMARY and USE_CHUNK_SUMMARY.

    Chunks are grouped by source document id and each group is processed
    in turn. The chunks are modified in place and the same list is returned.

    Args:
        chunks: chunks to enrich; may span multiple documents.
        llm: the LLM used to generate summaries and contexts.
        tokenizer: tokenizer used to size the prompts.
        chunk_token_limit: tokens set aside for the chunk content in the
            per-chunk prompt.
    """
    llm_input_budget = get_max_input_tokens(
        model_name=llm.config.model_name,
        model_provider=llm.config.model_provider,
        output_tokens=MAX_CONTEXT_TOKENS,
    )

    # Group chunks by their source document, keeping first-seen order
    chunks_per_doc: dict = {}
    for chunk in chunks:
        chunks_per_doc.setdefault(chunk.source_document.id, []).append(chunk)

    # Tokens available for the document text in the document summary prompt
    summary_prompt_len = len(tokenizer.encode(DOCUMENT_SUMMARY_PROMPT))
    trunc_doc_summary_tokens = llm_input_budget - summary_prompt_len

    # Tokens available for the document text in the
    # "chunk in context of document" prompt
    context_prompt_len = len(
        tokenizer.encode(CONTEXTUAL_RAG_PROMPT1 + CONTEXTUAL_RAG_PROMPT2)
    )
    trunc_doc_chunk_tokens = llm_input_budget - context_prompt_len - chunk_token_limit

    for doc_chunks in chunks_per_doc.values():
        doc_tokens = (
            add_document_summaries(
                doc_chunks, llm, tokenizer, trunc_doc_summary_tokens
            )
            if USE_DOCUMENT_SUMMARY
            else None
        )

        if USE_CHUNK_SUMMARY:
            add_chunk_summaries(
                doc_chunks, llm, tokenizer, trunc_doc_chunk_tokens, doc_tokens
            )

    return chunks
@log_function_time(debug_only=True)
def index_doc_batch(
*,
@@ -534,6 +706,8 @@ def index_doc_batch(
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
tenant_id: str,
enable_contextual_rag: bool = False,
llm: LLM | None = None,
ignore_time_skip: bool = False,
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
) -> IndexingPipelineResult:
@@ -596,6 +770,20 @@ def index_doc_batch(
# a common source of failure for the indexing pipeline
chunks: list[DocAwareChunk] = chunker.chunk(ctx.indexable_docs)
# contextual RAG
if enable_contextual_rag:
assert llm is not None, "must provide an LLM for contextual RAG"
llm_tokenizer = get_tokenizer(
model_name=llm.config.model_name,
provider_type=llm.config.model_provider,
)
# Because the chunker's tokens are different from the LLM's tokens,
# we add a fudge factor to ensure we truncate prompts to the LLM's token limit
chunks = add_contextual_summaries(
chunks, llm, llm_tokenizer, chunker.chunk_token_limit * 2
)
logger.debug("Starting embedding")
chunks_with_embeddings, embedding_failures = (
embed_chunks_with_failure_handling(
@@ -791,13 +979,33 @@ def build_indexing_pipeline(
callback: IndexingHeartbeatInterface | None = None,
) -> IndexingPipelineProtocol:
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
search_settings = get_current_search_settings(db_session)
all_search_settings = get_active_search_settings(db_session)
if (
all_search_settings.secondary
and all_search_settings.secondary.status == IndexModelStatus.FUTURE
):
search_settings = all_search_settings.secondary
else:
search_settings = all_search_settings.primary
multipass_config = get_multipass_config(search_settings)
enable_contextual_rag = (
search_settings.enable_contextual_rag or ENABLE_CONTEXTUAL_RAG
)
llm = None
if enable_contextual_rag:
llm = get_llm_for_contextual_rag(
search_settings.contextual_rag_llm_name or DEFAULT_CONTEXTUAL_RAG_LLM_NAME,
search_settings.contextual_rag_llm_provider
or DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER,
)
chunker = chunker or Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=multipass_config.multipass_indexing,
enable_large_chunks=multipass_config.enable_large_chunks,
enable_contextual_rag=enable_contextual_rag,
# after every doc, update status in case there are a bunch of really long docs
callback=callback,
)
@@ -811,4 +1019,6 @@ def build_indexing_pipeline(
ignore_time_skip=ignore_time_skip,
db_session=db_session,
tenant_id=tenant_id,
enable_contextual_rag=enable_contextual_rag,
llm=llm,
)

View File

@@ -49,6 +49,15 @@ class DocAwareChunk(BaseChunk):
metadata_suffix_semantic: str
metadata_suffix_keyword: str
# This is the number of tokens reserved for contextual RAG
# in the chunk. doc_summary and chunk_context combined should
# contain at most this many tokens.
contextual_rag_reserved_tokens: int
# This is the summary for the document generated for contextual RAG
doc_summary: str
# This is the context for this chunk generated for contextual RAG
chunk_context: str
mini_chunk_texts: list[str] | None
large_chunk_id: int | None
@@ -154,6 +163,9 @@ class IndexingSetting(EmbeddingModelDetail):
reduced_dimension: int | None = None
background_reindex_enabled: bool = True
enable_contextual_rag: bool
contextual_rag_llm_name: str | None = None
contextual_rag_llm_provider: str | None = None
# This disables the "model_" protected namespace for pydantic
model_config = {"protected_namespaces": ()}
@@ -178,6 +190,7 @@ class IndexingSetting(EmbeddingModelDetail):
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
background_reindex_enabled=search_settings.background_reindex_enabled,
enable_contextual_rag=search_settings.enable_contextual_rag,
)

View File

@@ -425,12 +425,12 @@ class DefaultMultiLLM(LLM):
messages=processed_prompt,
tools=tools,
tool_choice=tool_choice if tools else None,
max_tokens=max_tokens,
# streaming choice
stream=stream,
# model params
temperature=0,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@@ -531,6 +531,7 @@ class DefaultMultiLLM(LLM):
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
)
return

View File

@@ -16,6 +16,7 @@ from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import LLMProvider
from onyx.server.manage.llm.models import LLMProviderView
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
@@ -154,6 +155,40 @@ def get_default_llm_with_vision(
return None
def llm_from_provider(
    model_name: str,
    llm_provider: LLMProvider,
    timeout: int | None = None,
    temperature: float | None = None,
    additional_headers: dict[str, str] | None = None,
    long_term_logger: LongTermLogger | None = None,
) -> LLM:
    """
    Builds an LLM for `model_name` using the connection settings stored on
    `llm_provider`; the remaining arguments are forwarded to `get_llm` as-is.
    """
    llm = get_llm(
        # model selection
        model=model_name,
        # connection settings taken from the provider
        provider=llm_provider.provider,
        deployment_name=llm_provider.deployment_name,
        api_key=llm_provider.api_key,
        api_base=llm_provider.api_base,
        api_version=llm_provider.api_version,
        custom_config=llm_provider.custom_config,
        # caller-supplied overrides
        timeout=timeout,
        temperature=temperature,
        additional_headers=additional_headers,
        long_term_logger=long_term_logger,
    )
    return llm
def get_llm_for_contextual_rag(model_name: str, model_provider: str) -> LLM:
    """
    Looks up the LLM provider named `model_provider` and returns an LLM for
    `model_name` built from it.

    Raises:
        ValueError: if no provider with that name is found.
    """
    with get_session_context_manager() as db_session:
        provider_view = fetch_llm_provider_view(db_session, model_provider)

    if provider_view:
        return llm_from_provider(
            model_name=model_name,
            llm_provider=provider_view,
        )

    raise ValueError("No LLM provider with name {} found".format(model_provider))
def get_default_llms(
timeout: int | None = None,
temperature: float | None = None,
@@ -179,14 +214,9 @@ def get_default_llms(
raise ValueError("No fast default model name found")
def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
deployment_name=llm_provider.deployment_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
return llm_from_provider(
model_name=model,
llm_provider=llm_provider,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,

View File

@@ -29,13 +29,19 @@ from litellm.exceptions import Timeout # type: ignore
from litellm.exceptions import UnprocessableEntityError # type: ignore
from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import MessageType
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.configs.model_configs import GEN_AI_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.prompts.chat_prompts import CONTEXTUAL_RAG_TOKEN_ESTIMATE
from onyx.prompts.chat_prompts import DOCUMENT_SUMMARY_TOKEN_ESTIMATE
from onyx.prompts.constants import CODE_BLOCK_PAT
from onyx.utils.b64 import get_image_type
from onyx.utils.b64 import get_image_type_from_bytes
@@ -44,6 +50,10 @@ from shared_configs.configs import LOG_LEVEL
logger = setup_logger()
MAX_CONTEXT_TOKENS = 100
ONE_MILLION = 1_000_000
CHUNKS_PER_DOC_ESTIMATE = 5
def litellm_exception_to_error_msg(
e: Exception,
@@ -416,6 +426,72 @@ def _find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | N
return None
def get_llm_contextual_cost(
    llm: LLM,
) -> float:
    """
    Rough estimate of what it costs to index ONE_MILLION tokens of content
    with Contextual RAG using the given LLM.

    Assumptions behind the estimate:
    - prompt sizes come from precomputed token estimates,
    - every chunk is full (DOC_EMBEDDING_CONTEXT_SIZE tokens) and receives a
      full MAX_CONTEXT_TOKENS of generated context,
    - every document is MAX_TOKENS_FOR_FULL_INCLUSION tokens long; longer
      documents are represented by their summary instead of full content.

    The first assumptions are expected to overestimate more than the last one
    underestimates, so the result should be fairly conservative. Documents
    that fit in a single chunk (and are therefore not contextualized) are not
    discounted.

    Returns:
        The sum of the prompt and completion costs reported by
        `litellm.cost_per_token`, or 0 when neither chunk nor document
        summaries are enabled.
    """
    corpus_tokens = ONE_MILLION
    chunk_count = corpus_tokens // DOC_EMBEDDING_CONTEXT_SIZE
    # Documents are assumed to average MAX_TOKENS_FOR_FULL_INCLUSION tokens
    doc_count = corpus_tokens // MAX_TOKENS_FOR_FULL_INCLUSION

    # Neither summary type is enabled, so there is nothing to price
    if not (USE_CHUNK_SUMMARY or USE_DOCUMENT_SUMMARY):
        return 0

    prompt_token_total = 0
    completion_token_total = 0

    if USE_CHUNK_SUMMARY:
        # Per chunk: the contextual RAG prompt plus the full document content
        # (or the doc summary, which makes this an overestimate)
        per_chunk_overhead = (
            CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION
        )
        prompt_token_total += chunk_count * per_chunk_overhead

        # Across all prompts, each chunk's own content is sent exactly once,
        # which adds up to the whole corpus
        prompt_token_total += corpus_tokens

        # One MAX_CONTEXT_TOKENS-sized completion per chunk
        completion_token_total += chunk_count * MAX_CONTEXT_TOKENS

    # Summarizing each document once reads the whole corpus plus one summary
    # prompt per document. This CAN happen even when USE_DOCUMENT_SUMMARY is
    # false, because long documents are summarized when USE_CHUNK_SUMMARY is
    # true, so it is always included to keep the estimate on the high side.
    prompt_token_total += corpus_tokens + doc_count * DOCUMENT_SUMMARY_TOKEN_ESTIMATE
    completion_token_total += doc_count * MAX_CONTEXT_TOKENS

    prompt_cost_usd, completion_cost_usd = litellm.cost_per_token(
        model=llm.config.model_name,
        prompt_tokens=prompt_token_total,
        completion_tokens=completion_token_total,
    )
    # Costs are in USD for the one million content tokens assumed above
    return prompt_cost_usd + completion_cost_usd
def get_llm_max_tokens(
model_map: dict,
model_name: str,

View File

@@ -391,6 +391,11 @@ def get_application() -> FastAPI:
prefix="/auth",
)
if (
AUTH_TYPE == AuthType.CLOUD
or AUTH_TYPE == AuthType.BASIC
or AUTH_TYPE == AuthType.GOOGLE_OAUTH
):
# Add refresh token endpoint for OAuth as well
include_auth_router_with_prefix(
application,

View File

@@ -3,6 +3,8 @@ from abc import ABC
from abc import abstractmethod
from copy import copy
from tokenizers import Encoding # type: ignore
from tokenizers import Tokenizer # type: ignore
from transformers import logging as transformer_logging # type:ignore
from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
@@ -11,6 +13,8 @@ from onyx.context.search.models import InferenceChunk
from onyx.utils.logger import setup_logger
from shared_configs.enums import EmbeddingProvider
TRIM_SEP_PAT = "\n... {n} tokens removed...\n"
logger = setup_logger()
transformer_logging.set_verbosity_error()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -67,16 +71,27 @@ class TiktokenTokenizer(BaseTokenizer):
class HuggingFaceTokenizer(BaseTokenizer):
def __init__(self, model_name: str):
from tokenizers import Tokenizer # type: ignore
self.encoder: Tokenizer = Tokenizer.from_pretrained(model_name)
self.encoder = Tokenizer.from_pretrained(model_name)
def _safer_encode(self, string: str) -> Encoding:
"""
Encode a string using the HuggingFaceTokenizer, but if it fails,
encode the string as ASCII and decode it back to a string. This helps
in cases where the string has weird characters like \udeb4.
"""
try:
return self.encoder.encode(string, add_special_tokens=False)
except Exception:
return self.encoder.encode(
string.encode("ascii", "ignore").decode(), add_special_tokens=False
)
def encode(self, string: str) -> list[int]:
# this returns no special tokens
return self.encoder.encode(string, add_special_tokens=False).ids
return self._safer_encode(string).ids
def tokenize(self, string: str) -> list[str]:
return self.encoder.encode(string, add_special_tokens=False).tokens
return self._safer_encode(string).tokens
def decode(self, tokens: list[int]) -> str:
return self.encoder.decode(tokens)
@@ -159,9 +174,26 @@ def tokenizer_trim_content(
content: str, desired_length: int, tokenizer: BaseTokenizer
) -> str:
tokens = tokenizer.encode(content)
if len(tokens) > desired_length:
content = tokenizer.decode(tokens[:desired_length])
return content
if len(tokens) <= desired_length:
return content
return tokenizer.decode(tokens[:desired_length])
def tokenizer_trim_middle(
    tokens: list[int], desired_length: int, tokenizer: BaseTokenizer
) -> str:
    """
    Decodes `tokens` to text, cutting out the middle when there are more than
    `desired_length` tokens.

    When trimming, an equal number of tokens is kept from the start and from
    the end, joined by a TRIM_SEP_PAT marker. The marker's own token count is
    subtracted from the budget before it is split between the two ends.

    Args:
        tokens: token ids of the full text.
        desired_length: maximum number of tokens for the result.
        tokenizer: tokenizer used to decode the kept tokens and to measure
            the marker.

    Returns:
        The decoded text, either complete or with its middle replaced by the
        marker.

    Raises:
        AssertionError: if `desired_length` is too small to keep at least one
            token on each side of the marker.
    """
    if len(tokens) <= desired_length:
        return tokenizer.decode(tokens)

    # n is the nominal overage (input length minus the requested length)
    sep_str = TRIM_SEP_PAT.format(n=len(tokens) - desired_length)
    sep_tokens = tokenizer.encode(sep_str)
    slice_size = (desired_length - len(sep_tokens)) // 2

    # Explicit raise rather than `assert`: assertions are stripped under
    # `python -O`, and with slice_size == 0 the slice tokens[-0:] below would
    # return the ENTIRE token list instead of trimming anything.
    if slice_size <= 0:
        raise AssertionError(
            "Slice size is not positive, desired length is too short"
        )

    return (
        tokenizer.decode(tokens[:slice_size])
        + sep_str
        + tokenizer.decode(tokens[-slice_size:])
    )
def tokenizer_trim_chunks(

View File

@@ -220,3 +220,29 @@ Chat History:
Based on the above, what is a short name to convey the topic of the conversation?
""".strip()
# NOTE: the prompt separation is partially done for efficiency; previously I tried
# to do it all in one prompt with sequential format() calls but this will cause a backend
# error when the document contains any {} as python will expect the {} to be filled by
# format() arguments
CONTEXTUAL_RAG_PROMPT1 = """<document>
{document}
</document>
Here is the chunk we want to situate within the whole document"""
CONTEXTUAL_RAG_PROMPT2 = """<chunk>
{chunk}
</chunk>
Please give a short succinct context to situate this chunk within the overall document
for the purposes of improving search retrieval of the chunk. Answer only with the succinct
context and nothing else. """
CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64 # 19 + 45
DOCUMENT_SUMMARY_PROMPT = """<document>
{document}
</document>
Please give a short succinct summary of the entire document. Answer only with the succinct
summary and nothing else. """
DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 29

View File

@@ -87,6 +87,9 @@ def _create_indexable_chunks(
metadata_suffix_keyword="",
mini_chunk_texts=None,
large_chunk_reference_ids=[],
doc_summary="",
chunk_context="",
contextual_rag_reserved_tokens=0,
embeddings=ChunkEmbedding(
full_embedding=preprocessed_doc["content_embedding"],
mini_chunk_embeddings=[],

View File

@@ -21,9 +21,11 @@ from onyx.llm.factory import get_default_llms
from onyx.llm.factory import get_llm
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import get_llm_contextual_cost
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
@@ -286,3 +288,38 @@ def list_llm_provider_basics(
db_session, user
)
]
@admin_router.get("/provider-contextual-cost")
def get_provider_contextual_cost(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[LLMCost]:
    """
    Estimate, for each model of each existing LLM provider, the cost of
    re-indexing documents for contextual retrieval.

    See https://docs.litellm.ai/docs/completion/token_usage#5-cost_per_token

    The estimate comes from `get_llm_contextual_cost` and covers the LLM calls
    that generate:
    - the doc_summary
    - the chunk_context

    A provider's display model names are used when set, otherwise all of its
    model names.
    """
    cost_entries = []
    for llm_provider in fetch_existing_llm_providers(db_session):
        candidate_models = (
            llm_provider.display_model_names or llm_provider.model_names or []
        )
        for candidate in candidate_models:
            candidate_llm = get_llm(
                provider=llm_provider.provider,
                model=candidate,
                deployment_name=llm_provider.deployment_name,
                api_key=llm_provider.api_key,
                api_base=llm_provider.api_base,
                api_version=llm_provider.api_version,
                custom_config=llm_provider.custom_config,
            )
            cost_entries.append(
                LLMCost(
                    provider=llm_provider.name,
                    model_name=candidate,
                    cost=get_llm_contextual_cost(candidate_llm),
                )
            )

    return cost_entries

View File

@@ -119,3 +119,9 @@ class VisionProviderResponse(LLMProviderView):
"""Response model for vision providers endpoint, including vision-specific fields."""
vision_models: list[str]
class LLMCost(BaseModel):
    """Estimated contextual-retrieval indexing cost for one provider/model pair."""

    # Name of the configured LLM provider
    provider: str
    # Model offered by that provider
    model_name: str
    # Estimated cost for this model
    cost: float

View File

@@ -12,7 +12,6 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
@@ -42,9 +41,6 @@ from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from onyx.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
@@ -58,7 +54,6 @@ from onyx.utils.special_types import JSON_ro
logger = setup_logger()
SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"
@@ -357,13 +352,13 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
recency_bias_multiplier=search_pipeline.search_query.recency_bias_multiplier,
)
yield from yield_search_responses(
query,
lambda: search_pipeline.retrieved_sections,
lambda: search_pipeline.reranked_sections,
lambda: search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
query=query,
# give back the merged sections to prevent duplicate docs from appearing in the UI
get_retrieved_sections=lambda: search_pipeline.merged_retrieved_sections,
get_final_context_sections=lambda: search_pipeline.final_context_sections,
search_query_info=search_query_info,
get_section_relevance=lambda: search_pipeline.section_relevance,
search_tool=self,
)
def final_result(self, *args: ToolResponse) -> JSON_ro:
@@ -405,7 +400,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
def yield_search_responses(
query: str,
get_retrieved_sections: Callable[[], list[InferenceSection]],
get_reranked_sections: Callable[[], list[InferenceSection]],
get_final_context_sections: Callable[[], list[InferenceSection]],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
@@ -423,16 +417,6 @@ def yield_search_responses(
),
)
yield ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
context_from_inference_section(section)
for section in get_reranked_sections()
]
),
)
section_relevance = get_section_relevance()
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,

View File

@@ -1,5 +1,4 @@
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceSection
from onyx.prompts.prompt_utils import clean_up_source
@@ -32,10 +31,23 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
return doc_dict
def context_from_inference_section(section: InferenceSection) -> OnyxContext:
return OnyxContext(
content=section.combined_content,
def section_to_llm_doc(section: InferenceSection) -> LlmDoc:
possible_link_chunks = [section.center_chunk] + section.chunks
link: str | None = None
for chunk in possible_link_chunks:
if chunk.source_links:
link = list(chunk.source_links.values())[0]
break
return LlmDoc(
document_id=section.center_chunk.document_id,
content=section.combined_content,
source_type=section.center_chunk.source_type,
semantic_identifier=section.center_chunk.semantic_identifier,
metadata=section.center_chunk.metadata,
updated_at=section.center_chunk.updated_at,
blurb=section.center_chunk.blurb,
link=link,
source_links=section.center_chunk.source_links,
match_highlights=section.center_chunk.match_highlights,
)

View File

@@ -63,7 +63,10 @@ def generate_dummy_chunk(
title_prefix=f"Title prefix for doc {doc_id}",
metadata_suffix_semantic="",
metadata_suffix_keyword="",
doc_summary="",
chunk_context="",
mini_chunk_texts=None,
contextual_rag_reserved_tokens=0,
embeddings=ChunkEmbedding(
full_embedding=generate_random_embedding(embedding_dim),
mini_chunk_embeddings=[],
@@ -78,19 +81,19 @@ def generate_dummy_chunk(
for i in range(number_of_document_sets):
document_set_names.append(f"Document Set {i}")
user_emails: set[str | None] = set()
user_groups: set[str] = set()
external_user_emails: set[str] = set()
external_user_group_ids: set[str] = set()
user_emails: list[str | None] = []
user_groups: list[str] = []
external_user_emails: list[str] = []
external_user_group_ids: list[str] = []
for i in range(number_of_acl_entries):
user_emails.add(f"user_{i}@example.com")
user_groups.add(f"group_{i}")
external_user_emails.add(f"external_user_{i}@example.com")
external_user_group_ids.add(f"external_group_{i}")
user_emails.append(f"user_{i}@example.com")
user_groups.append(f"group_{i}")
external_user_emails.append(f"external_user_{i}@example.com")
external_user_group_ids.append(f"external_group_{i}")
return DocMetadataAwareIndexChunk.from_index_chunk(
index_chunk=chunk,
access=DocumentAccess(
access=DocumentAccess.build(
user_emails=user_emails,
user_groups=user_groups,
external_user_emails=external_user_emails,

View File

@@ -99,6 +99,7 @@ PRESERVED_SEARCH_FIELDS = [
"api_url",
"index_name",
"multipass_indexing",
"enable_contextual_rag",
"model_dim",
"normalize",
"passage_prefix",

View File

@@ -58,6 +58,16 @@ SECTIONS_FOLDER_URL = (
"https://drive.google.com/drive/u/5/folders/1loe6XJ-pJxu9YYPv7cF3Hmz296VNzA33"
)
EXTERNAL_SHARED_FOLDER_URL = (
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
)
EXTERNAL_SHARED_DOCS_IN_FOLDER = [
"https://docs.google.com/document/d/1Sywmv1-H6ENk2GcgieKou3kQHR_0te1mhIUcq8XlcdY"
]
EXTERNAL_SHARED_DOC_SINGLETON = (
"https://docs.google.com/document/d/11kmisDfdvNcw5LYZbkdPVjTOdj-Uc5ma6Jep68xzeeA"
)
SHARED_DRIVE_3_URL = "https://drive.google.com/drive/folders/0AJYm2K_I_vtNUk9PVA"
ADMIN_EMAIL = "admin@onyx-test.com"

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch
from urllib.parse import urlparse
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
@@ -9,6 +10,15 @@ from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_
from tests.daily.connectors.google_drive.consts_and_utils import (
assert_expected_docs_in_retrieved_docs,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOC_SINGLETON,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_DOCS_IN_FOLDER,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
EXTERNAL_SHARED_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
@@ -100,7 +110,8 @@ def test_include_shared_drives_only_with_size_threshold(
retrieved_docs = load_all_docs(connector)
assert len(retrieved_docs) == 50
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 52
@patch(
@@ -137,7 +148,8 @@ def test_include_shared_drives_only(
+ SECTIONS_FILE_IDS
)
assert len(retrieved_docs) == 51
# 2 extra files from shared drive owned by non-admin and not shared with admin
assert len(retrieved_docs) == 53
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
@@ -294,6 +306,64 @@ def test_folders_only(
)
def test_shared_folder_owned_by_external_user(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_folder_owned_by_external_user")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=False,
include_files_shared_with_me=False,
shared_drive_urls=None,
shared_folder_urls=EXTERNAL_SHARED_FOLDER_URL,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
expected_docs = EXTERNAL_SHARED_DOCS_IN_FOLDER
assert len(retrieved_docs) == len(expected_docs) # 1 for now
assert expected_docs[0] in retrieved_docs[0].id
def test_shared_with_me(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
) -> None:
print("\n\nRunning test_shared_with_me")
connector = google_drive_service_acct_connector_factory(
primary_admin_email=ADMIN_EMAIL,
include_shared_drives=False,
include_my_drives=True,
include_files_shared_with_me=True,
shared_drive_urls=None,
shared_folder_urls=None,
my_drive_emails=None,
)
retrieved_docs = load_all_docs(connector)
print(retrieved_docs)
expected_file_ids = (
ADMIN_FILE_IDS
+ ADMIN_FOLDER_3_FILE_IDS
+ TEST_USER_1_FILE_IDS
+ TEST_USER_2_FILE_IDS
+ TEST_USER_3_FILE_IDS
)
assert_expected_docs_in_retrieved_docs(
retrieved_docs=retrieved_docs,
expected_file_ids=expected_file_ids,
)
retrieved_ids = {urlparse(doc.id).path.split("/")[-2] for doc in retrieved_docs}
for id in retrieved_ids:
print(id)
assert EXTERNAL_SHARED_DOC_SINGLETON.split("/")[-1] in retrieved_ids
assert EXTERNAL_SHARED_DOCS_IN_FOLDER[0].split("/")[-1] in retrieved_ids
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,

View File

@@ -6,7 +6,7 @@ API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
API_SERVER_PORT = os.getenv("API_SERVER_PORT") or "8080"
API_SERVER_URL = f"{API_SERVER_PROTOCOL}://{API_SERVER_HOST}:{API_SERVER_PORT}"
MAX_DELAY = 45
MAX_DELAY = 60
GENERAL_HEADERS = {"Content-Type": "application/json"}

View File

@@ -5,6 +5,7 @@ import requests
from requests.models import Response
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SavedSearchDoc
from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
@@ -97,17 +98,24 @@ class ChatSessionManager:
for data in response_data:
if "rephrased_query" in data:
analyzed.rephrased_query = data["rephrased_query"]
elif "tool_name" in data:
if "tool_name" in data:
analyzed.tool_name = data["tool_name"]
analyzed.tool_result = (
data.get("tool_result")
if analyzed.tool_name == "run_search"
else None
)
elif "relevance_summaries" in data:
if "relevance_summaries" in data:
analyzed.relevance_summaries = data["relevance_summaries"]
elif "answer_piece" in data and data["answer_piece"]:
if "answer_piece" in data and data["answer_piece"]:
analyzed.full_message += data["answer_piece"]
if "top_documents" in data:
assert (
analyzed.top_documents is None
), "top_documents should only be set once"
analyzed.top_documents = [
SavedSearchDoc(**doc) for doc in data["top_documents"]
]
return analyzed

View File

@@ -10,6 +10,7 @@ from pydantic import Field
from onyx.auth.schemas import UserRole
from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.models import SavedSearchDoc
from onyx.db.enums import AccessType
from onyx.server.documents.models import DocumentSource
from onyx.server.documents.models import IndexAttemptSnapshot
@@ -157,7 +158,7 @@ class StreamedResponse(BaseModel):
full_message: str = ""
rephrased_query: str | None = None
tool_name: str | None = None
top_documents: list[dict[str, Any]] | None = None
top_documents: list[SavedSearchDoc] | None = None
relevance_summaries: list[dict[str, Any]] | None = None
tool_result: Any | None = None
user: str | None = None

View File

@@ -1,3 +1,6 @@
import os
import pytest
import requests
from onyx.auth.schemas import UserRole
@@ -6,6 +9,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
)
def test_saml_user_conversion(reset: None) -> None:
"""
Test that SAML login correctly converts users with non-authenticated roles

View File

@@ -5,6 +5,7 @@ This file contains tests for the following:
- updates the document sets and user groups to remove the connector
- Ensure that deleting a connector that is part of an overlapping document set and/or user group works as expected
"""
import os
from uuid import uuid4
from sqlalchemy.orm import Session
@@ -32,6 +33,13 @@ from tests.integration.common_utils.vespa import vespa_fixture
def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
user_group_1: DATestUserGroup
user_group_2: DATestUserGroup
is_ee = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create api key
@@ -78,16 +86,17 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
print("Document sets created and synced")
# create user groups
user_group_1: DATestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
user_group_2: DATestUserGroup = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(user_performing_action=admin_user)
if is_ee:
# create user groups
user_group_1 = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
user_group_2 = UserGroupManager.create(
cc_pair_ids=[cc_pair_1.id, cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(user_performing_action=admin_user)
# inject a finished index attempt and index attempt error (exercises foreign key errors)
with Session(get_sqlalchemy_engine()) as db_session:
@@ -147,12 +156,13 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
)
# Update local records to match the database for later comparison
user_group_1.cc_pair_ids = []
user_group_2.cc_pair_ids = [cc_pair_2.id]
doc_set_1.cc_pair_ids = []
doc_set_2.cc_pair_ids = [cc_pair_2.id]
cc_pair_1.groups = []
cc_pair_2.groups = [user_group_2.id]
if is_ee:
cc_pair_2.groups = [user_group_2.id]
else:
cc_pair_2.groups = []
CCPairManager.wait_for_deletion_completion(
cc_pair_id=cc_pair_1.id, user_performing_action=admin_user
@@ -168,11 +178,15 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
verify_deleted=True,
)
cc_pair_2_group_name_expected = []
if is_ee:
cc_pair_2_group_name_expected = [user_group_2.name]
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_set_names=[doc_set_2.name],
group_names=[user_group_2.name],
group_names=cc_pair_2_group_name_expected,
doc_creating_user=admin_user,
verify_deleted=False,
)
@@ -193,15 +207,19 @@ def test_connector_deletion(reset: None, vespa_client: vespa_fixture) -> None:
user_performing_action=admin_user,
)
# validate user groups
UserGroupManager.verify(
user_group=user_group_1,
user_performing_action=admin_user,
)
UserGroupManager.verify(
user_group=user_group_2,
user_performing_action=admin_user,
)
if is_ee:
user_group_1.cc_pair_ids = []
user_group_2.cc_pair_ids = [cc_pair_2.id]
# validate user groups
UserGroupManager.verify(
user_group=user_group_1,
user_performing_action=admin_user,
)
UserGroupManager.verify(
user_group=user_group_2,
user_performing_action=admin_user,
)
def test_connector_deletion_for_overlapping_connectors(
@@ -210,6 +228,13 @@ def test_connector_deletion_for_overlapping_connectors(
"""Checks to make sure that connectors with overlapping documents work properly. Specifically, that the overlapping
document (1) still exists and (2) has the right document set / group post-deletion of one of the connectors.
"""
user_group_1: DATestUserGroup
user_group_2: DATestUserGroup
is_ee = (
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
# create api key
@@ -281,47 +306,48 @@ def test_connector_deletion_for_overlapping_connectors(
doc_creating_user=admin_user,
)
# create a user group and attach it to connector 1
user_group_1: DATestUserGroup = UserGroupManager.create(
name="Test User Group 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1],
user_performing_action=admin_user,
)
cc_pair_1.groups = [user_group_1.id]
if is_ee:
# create a user group and attach it to connector 1
user_group_1 = UserGroupManager.create(
name="Test User Group 1",
cc_pair_ids=[cc_pair_1.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_1],
user_performing_action=admin_user,
)
cc_pair_1.groups = [user_group_1.id]
print("User group 1 created and synced")
print("User group 1 created and synced")
# create a user group and attach it to connector 2
user_group_2: DATestUserGroup = UserGroupManager.create(
name="Test User Group 2",
cc_pair_ids=[cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2],
user_performing_action=admin_user,
)
cc_pair_2.groups = [user_group_2.id]
# create a user group and attach it to connector 2
user_group_2 = UserGroupManager.create(
name="Test User Group 2",
cc_pair_ids=[cc_pair_2.id],
user_performing_action=admin_user,
)
UserGroupManager.wait_for_sync(
user_groups_to_check=[user_group_2],
user_performing_action=admin_user,
)
cc_pair_2.groups = [user_group_2.id]
print("User group 2 created and synced")
print("User group 2 created and synced")
# verify vespa document is in the user group
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
# verify vespa document is in the user group
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_1,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
group_names=[user_group_1.name, user_group_2.name],
doc_creating_user=admin_user,
)
# delete connector 1
CCPairManager.pause_cc_pair(
@@ -354,11 +380,15 @@ def test_connector_deletion_for_overlapping_connectors(
# verify the document is not in any document sets
# verify the document is only in user group 2
group_names_expected = []
if is_ee:
group_names_expected = [user_group_2.name]
DocumentManager.verify(
vespa_client=vespa_client,
cc_pair=cc_pair_2,
doc_set_names=[],
group_names=[user_group_2.name],
group_names=group_names_expected,
doc_creating_user=admin_user,
verify_deleted=False,
)

View File

@@ -1,3 +1,6 @@
import os
import pytest
import requests
from onyx.configs.constants import MessageType
@@ -12,6 +15,10 @@ from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history is enterprise only",
)
def test_all_stream_chat_message_objects_outputs(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,5 +1,7 @@
import json
import os
import pytest
import requests
from onyx.configs.constants import MessageType
@@ -16,10 +18,11 @@ from tests.integration.common_utils.test_models import DATestCCPair
from tests.integration.common_utils.test_models import DATestUser
def test_send_message_simple_with_history(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
# create connectors
cc_pair_1: DATestCCPair = CCPairManager.create_from_scratch(
user_performing_action=admin_user,
@@ -53,18 +56,22 @@ def test_send_message_simple_with_history(reset: None) -> None:
response_json = response.json()
# Check that the top document is the correct document
assert response_json["simple_search_docs"][0]["id"] == cc_pair_1.documents[0].id
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[0].id
# assert that the metadata is correct
for doc in cc_pair_1.documents:
found_doc = next(
(x for x in response_json["simple_search_docs"] if x["id"] == doc.id), None
(x for x in response_json["top_documents"] if x["document_id"] == doc.id),
None,
)
assert found_doc
assert found_doc["metadata"]["document_id"] == doc.id
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -154,6 +161,10 @@ def test_using_reference_docs_with_simple_with_history_api_flow(reset: None) ->
assert response_json["top_documents"][0]["document_id"] == cc_pair_1.documents[2].id
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="/chat/send-message-simple-with-history tests are enterprise only",
)
def test_send_message_simple_with_history_strict_json(
new_admin_user: DATestUser | None,
) -> None:

View File

@@ -2,6 +2,8 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connector-credential pairs.
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -15,6 +17,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and User Group tests are enterprise only",
)
def test_cc_pair_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -2,6 +2,8 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating connectors.
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -13,6 +15,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_connector_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -2,6 +2,8 @@
This file takes the happy path to adding a curator to a user group and then tests
the permissions of the curator manipulating credentials.
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -12,6 +14,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_credential_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,3 +1,5 @@
import os
import pytest
from requests.exceptions import HTTPError
@@ -10,6 +12,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_doc_set_permissions_setup(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -4,6 +4,8 @@ This file tests the permissions for creating and editing personas for different
- Curators can edit personas that belong exclusively to groups they curate
- Admins can edit all personas
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -13,6 +15,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_persona_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,6 +1,8 @@
"""
This file tests the ability of different user types to set the role of other users.
"""
import os
import pytest
from requests.exceptions import HTTPError
@@ -10,6 +12,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator and user group tests are enterprise only",
)
def test_user_role_setting_permissions(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,6 +1,10 @@
"""
This test tests the happy path for curator permissions
"""
import os
import pytest
from onyx.db.enums import AccessType
from onyx.db.models import UserRole
from onyx.server.documents.models import DocumentSource
@@ -12,6 +16,10 @@ from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator tests are enterprise only",
)
def test_whole_curator_flow(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
@@ -89,6 +97,10 @@ def test_whole_curator_flow(reset: None) -> None:
)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Curator tests are enterprise only",
)
def test_global_curator_flow(reset: None) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,3 +1,4 @@
import os
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -63,6 +64,10 @@ def setup_chat_session(reset: None) -> tuple[DATestUser, str]:
return admin_user, str(chat_session.id)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Chat history tests are enterprise only",
)
def test_chat_history_endpoints(
reset: None, setup_chat_session: tuple[DATestUser, str]
) -> None:
@@ -116,6 +121,10 @@ def test_chat_history_endpoints(
assert len(history_response.items) == 0
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Chat history tests are enterprise only",
)
def test_chat_history_csv_export(
reset: None, setup_chat_session: tuple[DATestUser, str]
) -> None:

View File

@@ -1,5 +1,8 @@
import os
from datetime import datetime
import pytest
from onyx.configs.constants import QAFeedbackType
from tests.integration.common_utils.managers.query_history import QueryHistoryManager
from tests.integration.common_utils.test_models import DAQueryHistoryEntry
@@ -47,6 +50,10 @@ def _verify_query_history_pagination(
assert all_expected_sessions == all_retrieved_sessions
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Query history tests are enterprise only",
)
def test_query_history_pagination(reset: None) -> None:
(
admin_user,

View File

@@ -0,0 +1,42 @@
from collections.abc import Callable
import pytest
from onyx.configs.constants import DocumentSource
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.document import DocumentManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.test_models import SimpleTestDocument
DocumentBuilderType = Callable[[list[str]], list[SimpleTestDocument]]
@pytest.fixture
def document_builder(admin_user: DATestUser) -> DocumentBuilderType:
api_key: DATestAPIKey = APIKeyManager.create(
user_performing_action=admin_user,
)
# create connector
cc_pair_1 = CCPairManager.create_from_scratch(
source=DocumentSource.INGESTION_API,
user_performing_action=admin_user,
)
def _document_builder(contents: list[str]) -> list[SimpleTestDocument]:
# seed documents
docs: list[SimpleTestDocument] = [
DocumentManager.seed_doc_with_content(
cc_pair=cc_pair_1,
content=content,
api_key=api_key,
)
for content in contents
]
return docs
return _document_builder

View File

@@ -5,12 +5,11 @@ import pytest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.tests.streaming_endpoints.conftest import DocumentBuilderType
def test_send_message_simple_with_history(reset: None) -> None:
admin_user: DATestUser = UserManager.create(name="admin_user")
def test_send_message_simple_with_history(reset: None, admin_user: DATestUser) -> None:
LLMProviderManager.create(user_performing_action=admin_user)
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
@@ -24,6 +23,44 @@ def test_send_message_simple_with_history(reset: None) -> None:
assert len(response.full_message) > 0
def test_send_message__basic_searches(
reset: None, admin_user: DATestUser, document_builder: DocumentBuilderType
) -> None:
MESSAGE = "run a search for 'test'"
SHORT_DOC_CONTENT = "test"
LONG_DOC_CONTENT = "blah blah blah blah" * 100
LLMProviderManager.create(user_performing_action=admin_user)
short_doc = document_builder([SHORT_DOC_CONTENT])[0]
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message=MESSAGE,
user_performing_action=admin_user,
)
assert response.top_documents is not None
assert len(response.top_documents) == 1
assert response.top_documents[0].document_id == short_doc.id
# make sure this doc is really long so that it will be split into multiple chunks
long_doc = document_builder([LONG_DOC_CONTENT])[0]
# new chat session for simplicity
test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message=MESSAGE,
user_performing_action=admin_user,
)
assert response.top_documents is not None
assert len(response.top_documents) == 2
# short doc should be more relevant and thus first
assert response.top_documents[0].document_id == short_doc.id
assert response.top_documents[1].document_id == long_doc.id
@pytest.mark.skip(
reason="enable for autorun when we have a testing environment with semantically useful data"
)

View File

@@ -8,6 +8,10 @@ This tests the deletion of a user group with the following foreign key constrain
- token_rate_limit (Not Implemented)
- persona
"""
import os
import pytest
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.managers.cc_pair import CCPairManager
from tests.integration.common_utils.managers.credential import CredentialManager
@@ -25,6 +29,10 @@ from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_user_group_deletion(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -1,3 +1,7 @@
import os
import pytest
from onyx.server.documents.models import DocumentSource
from tests.integration.common_utils.constants import NUM_DOCS
from tests.integration.common_utils.managers.api_key import APIKeyManager
@@ -11,6 +15,10 @@ from tests.integration.common_utils.test_models import DATestUserGroup
from tests.integration.common_utils.vespa import vespa_fixture
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_removing_connector(reset: None, vespa_client: vespa_fixture) -> None:
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")

View File

@@ -32,6 +32,8 @@ def create_test_chunk(
match_highlights=[],
updated_at=datetime.now(),
image_file_name=None,
doc_summary="",
chunk_context="",
)

View File

@@ -9,8 +9,6 @@ from onyx.chat.chat_utils import llm_doc_from_inference_section
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import DocumentSource
@@ -19,7 +17,6 @@ from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
@@ -81,6 +78,8 @@ def mock_inference_sections() -> list[InferenceSection]:
source_links={0: "https://example.com/doc1"},
match_highlights=[],
image_file_name=None,
doc_summary="",
chunk_context="",
),
chunks=MagicMock(),
),
@@ -104,6 +103,8 @@ def mock_inference_sections() -> list[InferenceSection]:
source_links={0: "https://example.com/doc2"},
match_highlights=[],
image_file_name=None,
doc_summary="",
chunk_context="",
),
chunks=MagicMock(),
),
@@ -120,24 +121,7 @@ def mock_search_results(
@pytest.fixture
def mock_contexts(mock_inference_sections: list[InferenceSection]) -> OnyxContexts:
return OnyxContexts(
contexts=[
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in mock_inference_sections
]
)
@pytest.fixture
def mock_search_tool(
mock_contexts: OnyxContexts, mock_search_results: list[LlmDoc]
) -> MagicMock:
def mock_search_tool(mock_search_results: list[LlmDoc]) -> MagicMock:
mock_tool = MagicMock(spec=SearchTool)
mock_tool.name = "search"
mock_tool.build_tool_message_content.return_value = "search_response"
@@ -146,7 +130,6 @@ def mock_search_tool(
json.loads(doc.model_dump_json()) for doc in mock_search_results
]
mock_tool.run.return_value = [
ToolResponse(id=SEARCH_DOC_CONTENT_ID, response=mock_contexts),
ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=mock_search_results),
]
mock_tool.tool_definition.return_value = {

View File

@@ -151,6 +151,8 @@ def test_fuzzy_match_quotes_to_docs() -> None:
match_highlights=[],
updated_at=None,
image_file_name=None,
doc_summary="",
chunk_context="",
)
test_chunk_1 = InferenceChunk(
document_id="test doc 1",
@@ -170,6 +172,8 @@ def test_fuzzy_match_quotes_to_docs() -> None:
match_highlights=[],
updated_at=None,
image_file_name=None,
doc_summary="",
chunk_context="",
)
test_quotes = [

View File

@@ -19,7 +19,6 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import CitationInfo
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
@@ -33,7 +32,6 @@ from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
@@ -141,7 +139,6 @@ def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
def test_answer_with_search_call(
answer_instance: Answer,
mock_search_results: list[LlmDoc],
mock_contexts: OnyxContexts,
mock_search_tool: MagicMock,
force_use_tool: ForceUseTool,
expected_tool_args: dict,
@@ -197,25 +194,21 @@ def test_answer_with_search_call(
tool_name="search", tool_args=expected_tool_args
)
assert output[1] == ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=mock_contexts,
)
assert output[2] == ToolResponse(
id="final_context_documents",
response=mock_search_results,
)
assert output[3] == ToolCallFinalResult(
assert output[2] == ToolCallFinalResult(
tool_name="search",
tool_args=expected_tool_args,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
)
assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
assert output[5] == expected_citation
assert output[6] == OnyxAnswerPiece(
assert output[4] == expected_citation
assert output[5] == OnyxAnswerPiece(
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
)
assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
expected_answer = (
"Based on the search results, "
@@ -268,7 +261,6 @@ def test_answer_with_search_call(
def test_answer_with_search_no_tool_calling(
answer_instance: Answer,
mock_search_results: list[LlmDoc],
mock_contexts: OnyxContexts,
mock_search_tool: MagicMock,
) -> None:
answer_instance.graph_config.tooling.tools = [mock_search_tool]
@@ -288,30 +280,26 @@ def test_answer_with_search_no_tool_calling(
output = list(answer_instance.processed_streamed_output)
# Assertions
assert len(output) == 8
assert len(output) == 7
assert output[0] == ToolCallKickoff(
tool_name="search", tool_args=DEFAULT_SEARCH_ARGS
)
assert output[1] == ToolResponse(
id=SEARCH_DOC_CONTENT_ID,
response=mock_contexts,
)
assert output[2] == ToolResponse(
id=FINAL_CONTEXT_DOCUMENTS_ID,
response=mock_search_results,
)
assert output[3] == ToolCallFinalResult(
assert output[2] == ToolCallFinalResult(
tool_name="search",
tool_args=DEFAULT_SEARCH_ARGS,
tool_result=[json.loads(doc.model_dump_json()) for doc in mock_search_results],
)
assert output[4] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
assert output[3] == OnyxAnswerPiece(answer_piece="Based on the search results, ")
expected_citation = CitationInfo(citation_num=1, document_id="doc1")
assert output[5] == expected_citation
assert output[6] == OnyxAnswerPiece(
assert output[4] == expected_citation
assert output[5] == OnyxAnswerPiece(
answer_piece="the answer is abc[[1]](https://example.com/doc1). "
)
assert output[7] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
assert output[6] == OnyxAnswerPiece(answer_piece="This is some other stuff.")
expected_answer = (
"Based on the search results, "

View File

@@ -38,6 +38,8 @@ def create_inference_chunk(
match_highlights=[],
updated_at=None,
image_file_name=None,
doc_summary="",
chunk_context="",
)

View File

@@ -79,7 +79,7 @@ def test_skip_gen_ai_answer_generation_flag(
for res in results:
print(res)
expected_count = 4 if skip_gen_ai_answer_generation else 5
expected_count = 3 if skip_gen_ai_answer_generation else 4
assert len(results) == expected_count
if not skip_gen_ai_answer_generation:
mock_llm.stream.assert_called_once()

View File

@@ -1,5 +1,6 @@
import pytest
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -17,3 +18,13 @@ class MockHeartbeat(IndexingHeartbeatInterface):
@pytest.fixture
def mock_heartbeat() -> MockHeartbeat:
return MockHeartbeat()
@pytest.fixture
def embedder() -> DefaultIndexingEmbedder:
return DefaultIndexingEmbedder(
model_name="intfloat/e5-base-v2",
normalize=True,
query_prefix=None,
passage_prefix=None,
)

View File

@@ -1,25 +1,24 @@
from typing import Any
from unittest.mock import Mock
import pytest
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import process_image_sections
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from tests.unit.onyx.indexing.conftest import MockHeartbeat
@pytest.fixture
def embedder() -> DefaultIndexingEmbedder:
return DefaultIndexingEmbedder(
model_name="intfloat/e5-base-v2",
normalize=True,
query_prefix=None,
passage_prefix=None,
)
def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
@pytest.mark.parametrize("enable_contextual_rag", [True, False])
def test_chunk_document(
embedder: DefaultIndexingEmbedder, enable_contextual_rag: bool
) -> None:
short_section_1 = "This is a short section."
long_section = (
"This is a long section that should be split into multiple chunks. " * 100
@@ -45,9 +44,22 @@ def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
)
indexing_documents = process_image_sections([document])
mock_llm_invoke_count = 0
def mock_llm_invoke(self: Any, *args: Any, **kwargs: Any) -> Mock:
nonlocal mock_llm_invoke_count
mock_llm_invoke_count += 1
m = Mock()
m.content = f"Test{mock_llm_invoke_count}"
return m
mock_llm = Mock()
mock_llm.invoke = mock_llm_invoke
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
enable_contextual_rag=enable_contextual_rag,
)
chunks = chunker.chunk(indexing_documents)
@@ -58,6 +70,14 @@ def test_chunk_document(embedder: DefaultIndexingEmbedder) -> None:
assert "tag1" in chunks[0].metadata_suffix_keyword
assert "tag2" in chunks[0].metadata_suffix_semantic
rag_tokens = MAX_CONTEXT_TOKENS * (
int(USE_DOCUMENT_SUMMARY) + int(USE_CHUNK_SUMMARY)
)
for chunk in chunks:
assert chunk.contextual_rag_reserved_tokens == (
rag_tokens if enable_contextual_rag else 0
)
def test_chunker_heartbeat(
embedder: DefaultIndexingEmbedder, mock_heartbeat: MockHeartbeat
@@ -78,6 +98,7 @@ def test_chunker_heartbeat(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
callback=mock_heartbeat,
enable_contextual_rag=False,
)
chunks = chunker.chunk(indexing_documents)

View File

@@ -21,7 +21,13 @@ def mock_embedding_model() -> Generator[Mock, None, None]:
yield mock
def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> None:
@pytest.mark.parametrize(
"chunk_context, doc_summary",
[("Test chunk context", "Test document summary"), ("", "")],
)
def test_default_indexing_embedder_embed_chunks(
mock_embedding_model: Mock, chunk_context: str, doc_summary: str
) -> None:
# Setup
embedder = DefaultIndexingEmbedder(
model_name="test-model",
@@ -63,6 +69,9 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> N
large_chunk_reference_ids=[],
large_chunk_id=None,
image_file_name=None,
chunk_context=chunk_context,
doc_summary=doc_summary,
contextual_rag_reserved_tokens=200,
)
]
@@ -81,7 +90,7 @@ def test_default_indexing_embedder_embed_chunks(mock_embedding_model: Mock) -> N
# Verify the embedding model was called correctly
mock_embedding_model.return_value.encode.assert_any_call(
texts=["Title: Test chunk"],
texts=[f"Title: {doc_summary}Test chunk{chunk_context}"],
text_type=EmbedTextType.PASSAGE,
large_chunks_present=False,
)

View File

@@ -1,6 +1,8 @@
from typing import Any
from typing import cast
from typing import List
from unittest.mock import Mock
from unittest.mock import patch
import pytest
@@ -9,8 +11,12 @@ from onyx.connectors.models import Document
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import _get_aggregated_chunk_boost_factor
from onyx.indexing.indexing_pipeline import add_contextual_summaries
from onyx.indexing.indexing_pipeline import filter_documents
from onyx.indexing.indexing_pipeline import process_image_sections
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import IndexChunk
from onyx.natural_language_processing.search_nlp_models import (
@@ -166,6 +172,9 @@ def create_test_chunk(
embeddings=ChunkEmbedding(full_embedding=[], mini_chunk_embeddings=[]),
title_embedding=None,
image_file_name=None,
chunk_context="",
doc_summary="",
contextual_rag_reserved_tokens=200,
)
@@ -249,3 +258,76 @@ def test_get_aggregated_boost_factor_individual_failure() -> None:
)
assert "Failed to predict content classification for chunk" in str(exc_info.value)
@patch("onyx.llm.utils.GEN_AI_MAX_TOKENS", 4096)
@pytest.mark.parametrize("enable_contextual_rag", [True, False])
def test_contextual_rag(
embedder: DefaultIndexingEmbedder, enable_contextual_rag: bool
) -> None:
short_section_1 = "This is a short section."
long_section = (
"This is a long section that should be split into multiple chunks. " * 100
)
short_section_2 = "This is another short section."
short_section_3 = "This is another short section again."
short_section_4 = "Final short section."
semantic_identifier = "Test Document"
document = Document(
id="test_doc",
source=DocumentSource.WEB,
semantic_identifier=semantic_identifier,
metadata={"tags": ["tag1", "tag2"]},
doc_updated_at=None,
sections=[
TextSection(text=short_section_1, link="link1"),
TextSection(text=short_section_2, link="link2"),
TextSection(text=long_section, link="link3"),
TextSection(text=short_section_3, link="link4"),
TextSection(text=short_section_4, link="link5"),
],
)
indexing_documents = process_image_sections([document])
mock_llm_invoke_count = 0
def mock_llm_invoke(self: Any, *args: Any, **kwargs: Any) -> Mock:
nonlocal mock_llm_invoke_count
mock_llm_invoke_count += 1
m = Mock()
m.content = f"Test{mock_llm_invoke_count}"
return m
llm_tokenizer = embedder.embedding_model.tokenizer
mock_llm = Mock()
mock_llm.invoke = mock_llm_invoke
chunker = Chunker(
tokenizer=embedder.embedding_model.tokenizer,
enable_multipass=False,
enable_contextual_rag=enable_contextual_rag,
)
chunks = chunker.chunk(indexing_documents)
chunks = add_contextual_summaries(
chunks, mock_llm, llm_tokenizer, chunker.chunk_token_limit * 2
)
assert len(chunks) == 5
assert short_section_1 in chunks[0].content
assert short_section_3 in chunks[-1].content
assert short_section_4 in chunks[-1].content
assert "tag1" in chunks[0].metadata_suffix_keyword
assert "tag2" in chunks[0].metadata_suffix_semantic
doc_summary = "Test1" if enable_contextual_rag else ""
chunk_context = ""
count = 2
for chunk in chunks:
if enable_contextual_rag:
chunk_context = f"Test{count}"
count += 1
assert chunk.doc_summary == doc_summary
assert chunk.chunk_context == chunk_context

View File

@@ -140,12 +140,12 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
],
tools=tools,
tool_choice=None,
max_tokens=None,
stream=False,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
max_tokens=None,
)
@@ -286,10 +286,10 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No
],
tools=tools,
tool_choice=None,
max_tokens=None,
stream=True,
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
max_tokens=None,
)

View File

@@ -45,7 +45,7 @@ export function ActionsTable({ tools }: { tools: ToolSnapshot[] }) {
className="mr-1 my-auto cursor-pointer"
onClick={() =>
router.push(
`/admin/tools/edit/${tool.id}?u=${Date.now()}`
`/admin/actions/edit/${tool.id}?u=${Date.now()}`
)
}
/>

View File

@@ -1,5 +1,8 @@
export const LLM_PROVIDERS_ADMIN_URL = "/api/admin/llm/provider";
export const LLM_CONTEXTUAL_COST_ADMIN_URL =
"/api/admin/llm/provider-contextual-cost";
export const EMBEDDING_PROVIDERS_ADMIN_URL =
"/api/admin/embedding/embedding-provider";

View File

@@ -143,6 +143,15 @@ function Main() {
</Text>
</div>
<div>
<Text className="font-semibold">Contextual RAG</Text>
<Text className="text-text-700">
{searchSettings.enable_contextual_rag
? "Enabled"
: "Disabled"}
</Text>
</div>
<div>
<Text className="font-semibold">
Disable Reranking for Streaming

View File

@@ -281,7 +281,7 @@ export default function AddConnector({
return (
<Formik
initialValues={{
...createConnectorInitialValues(connector),
...createConnectorInitialValues(connector, currentCredential),
...Object.fromEntries(
connectorConfigs[connector].advanced_values.map((field) => [
field.name,

View File

@@ -148,8 +148,7 @@ export function Explorer({
clearTimeout(timeoutId);
}
let doSearch = true;
if (doSearch) {
if (query && query.trim() !== "") {
router.replace(
`/admin/documents/explorer?query=${encodeURIComponent(query)}`
);

View File

@@ -26,9 +26,18 @@ export enum EmbeddingPrecision {
BFLOAT16 = "bfloat16",
}
export interface LLMContextualCost {
provider: string;
model_name: string;
cost: number;
}
export interface AdvancedSearchConfiguration {
index_name: string | null;
multipass_indexing: boolean;
enable_contextual_rag: boolean;
contextual_rag_llm_name: string | null;
contextual_rag_llm_provider: string | null;
multilingual_expansion: string[];
disable_rerank_for_streaming: boolean;
api_url: string | null;

View File

@@ -3,7 +3,11 @@ import { Formik, Form, FormikProps, FieldArray, Field } from "formik";
import * as Yup from "yup";
import { TrashIcon } from "@/components/icons/icons";
import { FaPlus } from "react-icons/fa";
import { AdvancedSearchConfiguration, EmbeddingPrecision } from "../interfaces";
import {
AdvancedSearchConfiguration,
EmbeddingPrecision,
LLMContextualCost,
} from "../interfaces";
import {
BooleanFormField,
Label,
@@ -12,6 +16,13 @@ import {
} from "@/components/admin/connectors/Field";
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";
import { StringOrNumberOption } from "@/components/Dropdown";
import useSWR from "swr";
import { LLM_CONTEXTUAL_COST_ADMIN_URL } from "../../configuration/llm/constants";
import { getDisplayNameForModel } from "@/lib/hooks";
import { errorHandlingFetcher } from "@/lib/fetcher";
// Number of tokens to show cost calculation for
const COST_CALCULATION_TOKENS = 1_000_000;
interface AdvancedEmbeddingFormPageProps {
updateAdvancedEmbeddingDetails: (
@@ -45,14 +56,66 @@ const AdvancedEmbeddingFormPage = forwardRef<
},
ref
) => {
// Fetch contextual costs
const { data: contextualCosts, error: costError } = useSWR<
LLMContextualCost[]
>(LLM_CONTEXTUAL_COST_ADMIN_URL, errorHandlingFetcher);
const llmOptions: StringOrNumberOption[] = React.useMemo(
() =>
(contextualCosts || []).map((cost) => {
return {
name: getDisplayNameForModel(cost.model_name),
value: cost.model_name,
};
}),
[contextualCosts]
);
// Helper function to format cost as USD
const formatCost = (cost: number) => {
return new Intl.NumberFormat("en-US", {
style: "currency",
currency: "USD",
}).format(cost);
};
// Get cost info for selected model
const getSelectedModelCost = (modelName: string | null) => {
if (!contextualCosts || !modelName) return null;
return contextualCosts.find((cost) => cost.model_name === modelName);
};
// Get the current value for the selector based on the parent state
const getCurrentLLMValue = React.useMemo(() => {
if (!advancedEmbeddingDetails.contextual_rag_llm_name) return null;
return advancedEmbeddingDetails.contextual_rag_llm_name;
}, [advancedEmbeddingDetails.contextual_rag_llm_name]);
return (
<div className="py-4 rounded-lg max-w-4xl px-4 mx-auto">
<Formik
innerRef={ref}
initialValues={advancedEmbeddingDetails}
initialValues={{
...advancedEmbeddingDetails,
contextual_rag_llm: getCurrentLLMValue,
}}
validationSchema={Yup.object().shape({
multilingual_expansion: Yup.array().of(Yup.string()),
multipass_indexing: Yup.boolean(),
enable_contextual_rag: Yup.boolean(),
contextual_rag_llm: Yup.string()
.nullable()
.test(
"required-if-contextual-rag",
"LLM must be selected when Contextual RAG is enabled",
function (value) {
const enableContextualRag = this.parent.enable_contextual_rag;
console.log("enableContextualRag", enableContextualRag);
console.log("value", value);
return !enableContextualRag || value !== null;
}
),
disable_rerank_for_streaming: Yup.boolean(),
num_rerank: Yup.number()
.required("Number of results to rerank is required")
@@ -79,10 +142,26 @@ const AdvancedEmbeddingFormPage = forwardRef<
validate={(values) => {
// Call updateAdvancedEmbeddingDetails for each changed field
Object.entries(values).forEach(([key, value]) => {
updateAdvancedEmbeddingDetails(
key as keyof AdvancedSearchConfiguration,
value
);
if (key === "contextual_rag_llm") {
const selectedModel = (contextualCosts || []).find(
(cost) => cost.model_name === value
);
if (selectedModel) {
updateAdvancedEmbeddingDetails(
"contextual_rag_llm_provider",
selectedModel.provider
);
updateAdvancedEmbeddingDetails(
"contextual_rag_llm_name",
selectedModel.model_name
);
}
} else {
updateAdvancedEmbeddingDetails(
key as keyof AdvancedSearchConfiguration,
value
);
}
});
// Run validation and report errors
@@ -96,6 +175,23 @@ const AdvancedEmbeddingFormPage = forwardRef<
.shape({
multilingual_expansion: Yup.array().of(Yup.string()),
multipass_indexing: Yup.boolean(),
enable_contextual_rag: Yup.boolean(),
contextual_rag_llm: Yup.string()
.nullable()
.test(
"required-if-contextual-rag",
"LLM must be selected when Contextual RAG is enabled",
function (value) {
const enableContextualRag =
this.parent.enable_contextual_rag;
console.log(
"enableContextualRag2",
enableContextualRag
);
console.log("value2", value);
return !enableContextualRag || value !== null;
}
),
disable_rerank_for_streaming: Yup.boolean(),
num_rerank: Yup.number()
.required("Number of results to rerank is required")
@@ -190,6 +286,56 @@ const AdvancedEmbeddingFormPage = forwardRef<
label="Disable Rerank for Streaming"
name="disable_rerank_for_streaming"
/>
<BooleanFormField
subtext="Enable contextual RAG for all chunk sizes."
optional
label="Contextual RAG"
name="enable_contextual_rag"
/>
<div>
<SelectorFormField
name="contextual_rag_llm"
label="Contextual RAG LLM"
subtext={
costError
? "Error loading LLM models. Please try again later."
: !contextualCosts
? "Loading available LLM models..."
: values.enable_contextual_rag
? "Select the LLM model to use for contextual RAG processing."
: "Enable Contextual RAG above to select an LLM model."
}
options={llmOptions}
disabled={
!values.enable_contextual_rag ||
!contextualCosts ||
!!costError
}
/>
{values.enable_contextual_rag &&
values.contextual_rag_llm &&
!costError && (
<div className="mt-2 text-sm text-text-600">
{contextualCosts ? (
<>
Estimated cost for processing{" "}
{COST_CALCULATION_TOKENS.toLocaleString()} tokens:{" "}
<span className="font-medium">
{getSelectedModelCost(values.contextual_rag_llm)
? formatCost(
getSelectedModelCost(
values.contextual_rag_llm
)!.cost
)
: "Cost information not available"}
</span>
</>
) : (
"Loading cost information..."
)}
</div>
)}
</div>
<NumberInput
description="Number of results to rerank"
optional={false}

View File

@@ -64,6 +64,9 @@ export default function EmbeddingForm() {
useState<AdvancedSearchConfiguration>({
index_name: "",
multipass_indexing: true,
enable_contextual_rag: false,
contextual_rag_llm_name: null,
contextual_rag_llm_provider: null,
multilingual_expansion: [],
disable_rerank_for_streaming: false,
api_url: null,
@@ -152,6 +155,9 @@ export default function EmbeddingForm() {
setAdvancedEmbeddingDetails({
index_name: searchSettings.index_name,
multipass_indexing: searchSettings.multipass_indexing,
enable_contextual_rag: searchSettings.enable_contextual_rag,
contextual_rag_llm_name: searchSettings.contextual_rag_llm_name,
contextual_rag_llm_provider: searchSettings.contextual_rag_llm_provider,
multilingual_expansion: searchSettings.multilingual_expansion,
disable_rerank_for_streaming:
searchSettings.disable_rerank_for_streaming,
@@ -197,7 +203,9 @@ export default function EmbeddingForm() {
searchSettings?.embedding_precision !=
advancedEmbeddingDetails.embedding_precision ||
searchSettings?.reduced_dimension !=
advancedEmbeddingDetails.reduced_dimension;
advancedEmbeddingDetails.reduced_dimension ||
searchSettings?.enable_contextual_rag !=
advancedEmbeddingDetails.enable_contextual_rag;
const updateSearch = useCallback(async () => {
if (!selectedProvider) {
@@ -384,6 +392,14 @@ export default function EmbeddingForm() {
advancedEmbeddingDetails.reduced_dimension && (
<li>Reduced dimension modification</li>
)}
{(searchSettings?.enable_contextual_rag !=
advancedEmbeddingDetails.enable_contextual_rag ||
searchSettings?.contextual_rag_llm_name !=
advancedEmbeddingDetails.contextual_rag_llm_name ||
searchSettings?.contextual_rag_llm_provider !=
advancedEmbeddingDetails.contextual_rag_llm_provider) && (
<li>Contextual RAG modification</li>
)}
</ul>
</div>
</div>
@@ -471,6 +487,11 @@ export default function EmbeddingForm() {
};
const handleReIndex = async () => {
console.log("handleReIndex");
console.log(selectedProvider);
console.log(advancedEmbeddingDetails);
console.log(rerankingDetails);
console.log(reindexType);
if (!selectedProvider) {
return;
}

View File

@@ -14,7 +14,7 @@ export default function LoginPage({
authTypeMetadata,
nextUrl,
searchParams,
showPageRedirect,
hidePageRedirect,
}: {
authUrl: string | null;
authTypeMetadata: AuthTypeMetadata | null;
@@ -24,7 +24,7 @@ export default function LoginPage({
[key: string]: string | string[] | undefined;
}
| undefined;
showPageRedirect?: boolean;
hidePageRedirect?: boolean;
}) {
useSendAuthRequiredMessage();
return (
@@ -75,7 +75,7 @@ export default function LoginPage({
<div className="flex flex-col gap-y-2 items-center"></div>
</>
)}
{showPageRedirect && (
{!hidePageRedirect && (
<p className="text-center mt-4">
Don&apos;t have an account?{" "}
<span

View File

@@ -72,6 +72,7 @@ const Page = async (props: {
authTypeMetadata={authTypeMetadata}
nextUrl={nextUrl!}
searchParams={searchParams}
hidePageRedirect={true}
/>
</AuthFlowContainer>
</div>

View File

@@ -91,7 +91,7 @@ export function AgenticToggle({
>
<div className="flex items-center space-x-2 mb-3">
<h3 className="text-sm font-semibold text-neutral-900">
Agent Search (BETA)
Agent Search
</h3>
</div>
<p className="text-xs text-neutral-600 dark:text-neutral-700 mb-2">

View File

@@ -347,7 +347,6 @@ export default function NRFPage({
<p className="p-4">Loading login info</p>
) : authType == "basic" ? (
<LoginPage
showPageRedirect
authUrl={null}
authTypeMetadata={{
authType: authType as AuthType,

View File

@@ -34,11 +34,12 @@
/* -------------------------------------------------------
* 2. Keep special, custom, or near-duplicate background
* ------------------------------------------------------- */
--background: #fefcfa; /* slightly off-white, keep it */
--background: #fefcfa; /* slightly off-white */
--background-50: #fffdfb; /* a little lighter than background but not quite white */
--input-background: #fefcfa;
--input-border: #f1eee8;
--text-text: #f4f2ed;
--background-dark: #e9e6e0;
--background-dark: #141414;
--new-background: #ebe7de;
--new-background-light: #d9d1c0;
--background-chatbar: #f5f3ee;
@@ -234,6 +235,7 @@
--text-text: #1d1d1d;
--background-dark: #252525;
--background-50: #252525;
/* --new-background: #fff; */
--new-background: #2c2c2c;

Some files were not shown because too many files have changed in this diff Show More