mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-29 03:22:43 +00:00
Compare commits
50 Commits
cli/v0.1.2
...
dane/index
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b5e30b829 | ||
|
|
a307900e4f | ||
|
|
f92a3c3d60 | ||
|
|
a33afcf912 | ||
|
|
c730850e81 | ||
|
|
b671bf4d4e | ||
|
|
6b7d9f4cfb | ||
|
|
399e251d85 | ||
|
|
6c38a28cf6 | ||
|
|
93d2f6d552 | ||
|
|
9124a6110d | ||
|
|
26850a42b3 | ||
|
|
c837d1ba80 | ||
|
|
fac7887542 | ||
|
|
6985661dcd | ||
|
|
3e2a10ce9d | ||
|
|
389eb6c281 | ||
|
|
ff88d1886b | ||
|
|
18dac2ba71 | ||
|
|
96cd5bb751 | ||
|
|
30a7c40c55 | ||
|
|
641fb61c45 | ||
|
|
6f8d9cfdd7 | ||
|
|
2784e42cfe | ||
|
|
4f5fc65428 | ||
|
|
8fcdd3a3fb | ||
|
|
3b7c53aeb1 | ||
|
|
ea58e82aed | ||
|
|
bd35585785 | ||
|
|
cf9bd7e511 | ||
|
|
b5dd17a371 | ||
|
|
d62d0c1864 | ||
|
|
2c92742c62 | ||
|
|
1e1402e4f1 | ||
|
|
440818a082 | ||
|
|
bd9f40d1c1 | ||
|
|
c85e090c13 | ||
|
|
d72df59063 | ||
|
|
867442bc54 | ||
|
|
f752761e46 | ||
|
|
a760d1cf33 | ||
|
|
acffd55ce4 | ||
|
|
3a4be4a7d9 | ||
|
|
7c0e7eddbd | ||
|
|
2e5763c9ab | ||
|
|
5c45345521 | ||
|
|
0665f31a7d | ||
|
|
17442ed2d0 | ||
|
|
5b0c2f3c18 | ||
|
|
cff564eb6a |
@@ -787,6 +787,10 @@ MINI_CHUNK_SIZE = 150
|
||||
# This is the number of regular chunks per large chunk
|
||||
LARGE_CHUNK_RATIO = 4
|
||||
|
||||
# The maximum number of chunks that can be held for 1 document processing batch
|
||||
# The purpose of this is to set an upper bound on memory usage
|
||||
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)
|
||||
|
||||
# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
|
||||
# We don't want the metadata to overwhelm the actual contents of the chunk
|
||||
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
|
||||
|
||||
@@ -5,6 +5,7 @@ accidentally reaches the vector DB layer will fail loudly instead of timing
|
||||
out against a nonexistent Vespa/OpenSearch instance.
|
||||
"""
|
||||
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
from onyx.context.search.models import IndexFilters
|
||||
@@ -66,7 +67,7 @@ class DisabledDocumentIndex(DocumentIndex):
|
||||
# ------------------------------------------------------------------
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002
|
||||
index_batch_params: IndexBatchParams, # noqa: ARG002
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
raise RuntimeError(VECTOR_DB_DISABLED_ERROR)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
@@ -206,7 +207,7 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[DocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -226,8 +227,8 @@ class Indexable(abc.ABC):
|
||||
it is done automatically outside of this code.
|
||||
|
||||
Parameters:
|
||||
- chunks: Document chunks with all of the information needed for indexing to the document
|
||||
index.
|
||||
- chunks: Document chunks with all of the information needed for
|
||||
indexing to the document index.
|
||||
- tenant_id: The tenant id of the user whose chunks are being indexed
|
||||
- large_chunks_enabled: Whether large chunks are enabled
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import abc
|
||||
from collections.abc import Iterable
|
||||
from typing import Self
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -209,10 +210,10 @@ class Indexable(abc.ABC):
|
||||
@abc.abstractmethod
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
This is often a batch operation including chunks from multiple
|
||||
documents.
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
import json
|
||||
from collections import defaultdict
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
from opensearchpy import NotFoundError
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
@@ -350,7 +351,7 @@ class OpenSearchOldDocumentIndex(OldDocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
@@ -646,10 +647,10 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata, # noqa: ARG002
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
"""Indexes a list of document chunks into the document index.
|
||||
"""Indexes an iterable of document chunks into the document index.
|
||||
|
||||
Groups chunks by document ID and for each document, deletes existing
|
||||
chunks and indexes the new chunks in bulk.
|
||||
@@ -672,29 +673,34 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
document is newly indexed or had already existed and was just
|
||||
updated.
|
||||
"""
|
||||
# Group chunks by document ID.
|
||||
doc_id_to_chunks: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(
|
||||
list
|
||||
total_chunks = sum(
|
||||
cc.new_chunk_cnt
|
||||
for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values()
|
||||
)
|
||||
for chunk in chunks:
|
||||
doc_id_to_chunks[chunk.source_document.id].append(chunk)
|
||||
logger.debug(
|
||||
f"[OpenSearchDocumentIndex] Indexing {len(chunks)} chunks from {len(doc_id_to_chunks)} "
|
||||
f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} "
|
||||
f"documents for index {self._index_name}."
|
||||
)
|
||||
|
||||
document_indexing_results: list[DocumentInsertionRecord] = []
|
||||
# Try to index per-document.
|
||||
for _, chunks in doc_id_to_chunks.items():
|
||||
deleted_doc_ids: set[str] = set()
|
||||
# Buffer chunks per document as they arrive from the iterable.
|
||||
# When the document ID changes flush the buffered chunks.
|
||||
current_doc_id: str | None = None
|
||||
current_chunks: list[DocMetadataAwareIndexChunk] = []
|
||||
|
||||
def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None:
|
||||
assert len(doc_chunks) > 0, "doc_chunks is empty"
|
||||
|
||||
# Create a batch of OpenSearch-formatted chunks for bulk insertion.
|
||||
# Do this before deleting existing chunks to reduce the amount of
|
||||
# time the document index has no content for a given document, and
|
||||
# to reduce the chance of entering a state where we delete chunks,
|
||||
# then some error happens, and never successfully index new chunks.
|
||||
# Since we are doing this in batches, an error occurring midway
|
||||
# can result in a state where chunks are deleted and not all the
|
||||
# new chunks have been indexed.
|
||||
chunk_batch: list[DocumentChunk] = [
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk) for chunk in chunks
|
||||
_convert_onyx_chunk_to_opensearch_document(chunk)
|
||||
for chunk in doc_chunks
|
||||
]
|
||||
onyx_document: Document = chunks[0].source_document
|
||||
onyx_document: Document = doc_chunks[0].source_document
|
||||
# First delete the doc's chunks from the index. This is so that
|
||||
# there are no dangling chunks in the index, in the event that the
|
||||
# new document's content contains fewer chunks than the previous
|
||||
@@ -703,22 +709,43 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
# if the chunk count has actually decreased. This assumes that
|
||||
# overlapping chunks are perfectly overwritten. If we can't
|
||||
# guarantee that then we need the code as-is.
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed.
|
||||
document_insertion_record = DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
if onyx_document.id not in deleted_doc_ids:
|
||||
num_chunks_deleted = self.delete(
|
||||
onyx_document.id, onyx_document.chunk_count
|
||||
)
|
||||
deleted_doc_ids.add(onyx_document.id)
|
||||
# If we see that chunks were deleted we assume the doc already
|
||||
# existed. We record the result before bulk_index_documents
|
||||
# runs. If indexing raises, this entire result list is discarded
|
||||
# by the caller's retry logic, so early recording is safe.
|
||||
document_indexing_results.append(
|
||||
DocumentInsertionRecord(
|
||||
document_id=onyx_document.id,
|
||||
already_existed=num_chunks_deleted > 0,
|
||||
)
|
||||
)
|
||||
# Now index. This will raise if a chunk of the same ID exists, which
|
||||
# we do not expect because we should have deleted all chunks.
|
||||
self._client.bulk_index_documents(
|
||||
documents=chunk_batch,
|
||||
tenant_state=self._tenant_state,
|
||||
)
|
||||
document_indexing_results.append(document_insertion_record)
|
||||
|
||||
for chunk in chunks:
|
||||
doc_id = chunk.source_document.id
|
||||
if doc_id != current_doc_id:
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
current_doc_id = doc_id
|
||||
current_chunks = [chunk]
|
||||
elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH:
|
||||
_flush_chunks(current_chunks)
|
||||
current_chunks = [chunk]
|
||||
else:
|
||||
current_chunks.append(chunk)
|
||||
|
||||
if current_chunks:
|
||||
_flush_chunks(current_chunks)
|
||||
|
||||
return document_indexing_results
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ import re
|
||||
import time
|
||||
import urllib
|
||||
import zipfile
|
||||
from collections.abc import Iterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
@@ -461,7 +462,7 @@ class VespaIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> set[OldDocumentInsertionRecord]:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import random
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterable
|
||||
from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
@@ -8,6 +10,7 @@ import httpx
|
||||
from pydantic import BaseModel
|
||||
from retry import retry
|
||||
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
|
||||
from onyx.configs.app_configs import RERANK_COUNT
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
@@ -318,7 +321,7 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
def index(
|
||||
self,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
chunks: Iterable[DocMetadataAwareIndexChunk],
|
||||
indexing_metadata: IndexingMetadata,
|
||||
) -> list[DocumentInsertionRecord]:
|
||||
doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff
|
||||
@@ -338,22 +341,31 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
# Vespa has restrictions on valid characters, yet document IDs come from
|
||||
# external w.r.t. this class. We need to sanitize them.
|
||||
cleaned_chunks: list[DocMetadataAwareIndexChunk] = [
|
||||
clean_chunk_id_copy(chunk) for chunk in chunks
|
||||
]
|
||||
assert len(cleaned_chunks) == len(
|
||||
chunks
|
||||
), "Bug: Cleaned chunks and input chunks have different lengths."
|
||||
#
|
||||
# Instead of materializing all cleaned chunks upfront, we stream them
|
||||
# through a generator that cleans IDs and builds the original-ID mapping
|
||||
# incrementally as chunks flow into Vespa.
|
||||
def _clean_and_track(
|
||||
chunks_iter: Iterable[DocMetadataAwareIndexChunk],
|
||||
id_map: dict[str, str],
|
||||
seen_ids: set[str],
|
||||
) -> Generator[DocMetadataAwareIndexChunk, None, None]:
|
||||
"""Cleans chunk IDs and builds the original-ID mapping
|
||||
incrementally as chunks flow through, avoiding a separate
|
||||
materialization pass."""
|
||||
for chunk in chunks_iter:
|
||||
original_id = chunk.source_document.id
|
||||
cleaned = clean_chunk_id_copy(chunk)
|
||||
cleaned_id = cleaned.source_document.id
|
||||
# Needed so the final DocumentInsertionRecord returned can have
|
||||
# the original document ID. cleaned_chunks might not contain IDs
|
||||
# exactly as callers supplied them.
|
||||
id_map[cleaned_id] = original_id
|
||||
seen_ids.add(cleaned_id)
|
||||
yield cleaned
|
||||
|
||||
# Needed so the final DocumentInsertionRecord returned can have the
|
||||
# original document ID. cleaned_chunks might not contain IDs exactly as
|
||||
# callers supplied them.
|
||||
new_document_id_to_original_document_id: dict[str, str] = dict()
|
||||
for i, cleaned_chunk in enumerate(cleaned_chunks):
|
||||
old_chunk = chunks[i]
|
||||
new_document_id_to_original_document_id[
|
||||
cleaned_chunk.source_document.id
|
||||
] = old_chunk.source_document.id
|
||||
new_document_id_to_original_document_id: dict[str, str] = {}
|
||||
all_cleaned_doc_ids: set[str] = set()
|
||||
|
||||
existing_docs: set[str] = set()
|
||||
|
||||
@@ -409,8 +421,16 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
# Insert new Vespa documents.
|
||||
for chunk_batch in batch_generator(cleaned_chunks, BATCH_SIZE):
|
||||
# Insert new Vespa documents, streaming through the cleaning
|
||||
# pipeline so chunks are never fully materialized.
|
||||
cleaned_chunks = _clean_and_track(
|
||||
chunks,
|
||||
new_document_id_to_original_document_id,
|
||||
all_cleaned_doc_ids,
|
||||
)
|
||||
for chunk_batch in batch_generator(
|
||||
cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH)
|
||||
):
|
||||
batch_index_vespa_chunks(
|
||||
chunks=chunk_batch,
|
||||
index_name=self._index_name,
|
||||
@@ -419,10 +439,6 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
executor=executor,
|
||||
)
|
||||
|
||||
all_cleaned_doc_ids: set[str] = {
|
||||
chunk.source_document.id for chunk in cleaned_chunks
|
||||
}
|
||||
|
||||
return [
|
||||
DocumentInsertionRecord(
|
||||
document_id=new_document_id_to_original_document_id[cleaned_doc_id],
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
from collections.abc import Generator
|
||||
|
||||
@@ -19,7 +21,8 @@ from onyx.db.document import update_docs_updated_at__no_commit
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
|
||||
from onyx.indexing.models import BuildMetadataAwareChunksResult
|
||||
from onyx.indexing.models import ChunkEnrichmentContext
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
@@ -85,14 +88,21 @@ class DocumentIndexingBatchAdapter:
|
||||
) as transaction:
|
||||
yield transaction
|
||||
|
||||
def build_metadata_aware_chunks(
|
||||
def prepare_enrichment(
|
||||
self,
|
||||
chunks_with_embeddings: list[IndexChunk],
|
||||
chunk_content_scores: list[float],
|
||||
tenant_id: str,
|
||||
context: DocumentBatchPrepareContext,
|
||||
) -> BuildMetadataAwareChunksResult:
|
||||
"""Enrich chunks with access, document sets, boosts, token counts, and hierarchy."""
|
||||
tenant_id: str,
|
||||
chunks: list[DocAwareChunk],
|
||||
) -> DocumentChunkEnricher:
|
||||
"""Do all DB lookups once and return a per-chunk enricher."""
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
doc_id: 0 for doc_id in updatable_ids
|
||||
}
|
||||
for chunk in chunks:
|
||||
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
|
||||
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
|
||||
|
||||
no_access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
@@ -102,67 +112,30 @@ class DocumentIndexingBatchAdapter:
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
doc_id_to_access_info = get_access_for_documents(
|
||||
document_ids=updatable_ids, db_session=self.db_session
|
||||
)
|
||||
doc_id_to_document_set = {
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
return DocumentChunkEnricher(
|
||||
doc_id_to_access_info=get_access_for_documents(
|
||||
document_ids=updatable_ids, db_session=self.db_session
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int] = {
|
||||
document_id: chunk_count
|
||||
for document_id, chunk_count in fetch_chunk_counts_for_documents(
|
||||
document_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
}
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
doc_id: 0 for doc_id in updatable_ids
|
||||
}
|
||||
for chunk in chunks_with_embeddings:
|
||||
if chunk.source_document.id in doc_id_to_new_chunk_cnt:
|
||||
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
|
||||
|
||||
# Get ancestor hierarchy node IDs for each document
|
||||
doc_id_to_ancestor_ids = self._get_ancestor_ids_for_documents(
|
||||
context.updatable_docs, tenant_id
|
||||
)
|
||||
|
||||
access_aware_chunks = [
|
||||
DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=doc_id_to_access_info.get(chunk.source_document.id, no_access),
|
||||
document_sets=set(
|
||||
doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=(
|
||||
context.id_to_boost_map[chunk.source_document.id]
|
||||
if chunk.source_document.id in context.id_to_boost_map
|
||||
else DEFAULT_BOOST
|
||||
),
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
ancestor_hierarchy_node_ids=doc_id_to_ancestor_ids[
|
||||
chunk.source_document.id
|
||||
],
|
||||
)
|
||||
for chunk_num, chunk in enumerate(chunks_with_embeddings)
|
||||
]
|
||||
|
||||
return BuildMetadataAwareChunksResult(
|
||||
chunks=access_aware_chunks,
|
||||
doc_id_to_previous_chunk_cnt=doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=doc_id_to_new_chunk_cnt,
|
||||
user_file_id_to_raw_text={},
|
||||
user_file_id_to_token_count={},
|
||||
),
|
||||
doc_id_to_document_set={
|
||||
document_id: document_sets
|
||||
for document_id, document_sets in fetch_document_sets_for_documents(
|
||||
document_ids=updatable_ids, db_session=self.db_session
|
||||
)
|
||||
},
|
||||
doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
|
||||
context.updatable_docs, tenant_id
|
||||
),
|
||||
id_to_boost_map=context.id_to_boost_map,
|
||||
doc_id_to_previous_chunk_cnt={
|
||||
document_id: chunk_count
|
||||
for document_id, chunk_count in fetch_chunk_counts_for_documents(
|
||||
document_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
},
|
||||
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
|
||||
no_access=no_access,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
def _get_ancestor_ids_for_documents(
|
||||
@@ -203,7 +176,7 @@ class DocumentIndexingBatchAdapter:
|
||||
context: DocumentBatchPrepareContext,
|
||||
updatable_chunk_data: list[UpdatableChunkData],
|
||||
filtered_documents: list[Document],
|
||||
result: BuildMetadataAwareChunksResult,
|
||||
enrichment: ChunkEnrichmentContext,
|
||||
) -> None:
|
||||
"""Finalize DB updates, store plaintext, and mark docs as indexed."""
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
@@ -227,7 +200,7 @@ class DocumentIndexingBatchAdapter:
|
||||
|
||||
update_docs_chunk_count__no_commit(
|
||||
document_ids=updatable_ids,
|
||||
doc_id_to_chunk_count=result.doc_id_to_new_chunk_cnt,
|
||||
doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
|
||||
@@ -249,3 +222,52 @@ class DocumentIndexingBatchAdapter:
|
||||
)
|
||||
|
||||
self.db_session.commit()
|
||||
|
||||
|
||||
class DocumentChunkEnricher:
|
||||
"""Pre-computed metadata for per-chunk enrichment of connector documents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
doc_id_to_access_info: dict[str, DocumentAccess],
|
||||
doc_id_to_document_set: dict[str, list[str]],
|
||||
doc_id_to_ancestor_ids: dict[str, list[int]],
|
||||
id_to_boost_map: dict[str, int],
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int],
|
||||
doc_id_to_new_chunk_cnt: dict[str, int],
|
||||
no_access: DocumentAccess,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
self._doc_id_to_access_info = doc_id_to_access_info
|
||||
self._doc_id_to_document_set = doc_id_to_document_set
|
||||
self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
|
||||
self._id_to_boost_map = id_to_boost_map
|
||||
self._no_access = no_access
|
||||
self._tenant_id = tenant_id
|
||||
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
|
||||
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
|
||||
|
||||
def enrich_chunk(
|
||||
self, chunk: IndexChunk, score: float
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=self._doc_id_to_access_info.get(
|
||||
chunk.source_document.id, self._no_access
|
||||
),
|
||||
document_sets=set(
|
||||
self._doc_id_to_document_set.get(chunk.source_document.id, [])
|
||||
),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=(
|
||||
self._id_to_boost_map[chunk.source_document.id]
|
||||
if chunk.source_document.id in self._id_to_boost_map
|
||||
else DEFAULT_BOOST
|
||||
),
|
||||
tenant_id=self._tenant_id,
|
||||
aggregated_chunk_boost_factor=score,
|
||||
ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
|
||||
chunk.source_document.id
|
||||
],
|
||||
)
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import datetime
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator
|
||||
from uuid import UUID
|
||||
|
||||
@@ -24,7 +27,8 @@ from onyx.db.user_file import fetch_persona_ids_for_user_files
|
||||
from onyx.db.user_file import fetch_user_project_ids_for_user_files
|
||||
from onyx.file_store.utils import store_user_file_plaintext
|
||||
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
|
||||
from onyx.indexing.models import BuildMetadataAwareChunksResult
|
||||
from onyx.indexing.models import ChunkEnrichmentContext
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
@@ -101,13 +105,20 @@ class UserFileIndexingAdapter:
|
||||
f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
|
||||
)
|
||||
|
||||
def build_metadata_aware_chunks(
|
||||
def prepare_enrichment(
|
||||
self,
|
||||
chunks_with_embeddings: list[IndexChunk],
|
||||
chunk_content_scores: list[float],
|
||||
tenant_id: str,
|
||||
context: DocumentBatchPrepareContext,
|
||||
) -> BuildMetadataAwareChunksResult:
|
||||
tenant_id: str,
|
||||
chunks: list[DocAwareChunk],
|
||||
) -> UserFileChunkEnricher:
|
||||
"""Do all DB lookups and pre-compute file metadata from chunks."""
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
|
||||
content_by_file: dict[str, list[str]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
|
||||
content_by_file[chunk.source_document.id].append(chunk.content)
|
||||
|
||||
no_access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
@@ -117,7 +128,6 @@ class UserFileIndexingAdapter:
|
||||
is_public=False,
|
||||
)
|
||||
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
|
||||
user_file_ids=updatable_ids,
|
||||
db_session=self.db_session,
|
||||
@@ -138,17 +148,6 @@ class UserFileIndexingAdapter:
|
||||
)
|
||||
}
|
||||
|
||||
user_file_id_to_new_chunk_cnt: dict[str, int] = {
|
||||
user_file_id: len(
|
||||
[
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == user_file_id
|
||||
]
|
||||
)
|
||||
for user_file_id in updatable_ids
|
||||
}
|
||||
|
||||
# Initialize tokenizer used for token count calculation
|
||||
try:
|
||||
llm = get_default_llm()
|
||||
@@ -163,15 +162,9 @@ class UserFileIndexingAdapter:
|
||||
user_file_id_to_raw_text: dict[str, str] = {}
|
||||
user_file_id_to_token_count: dict[str, int | None] = {}
|
||||
for user_file_id in updatable_ids:
|
||||
user_file_chunks = [
|
||||
chunk
|
||||
for chunk in chunks_with_embeddings
|
||||
if chunk.source_document.id == user_file_id
|
||||
]
|
||||
if user_file_chunks:
|
||||
combined_content = " ".join(
|
||||
[chunk.content for chunk in user_file_chunks]
|
||||
)
|
||||
contents = content_by_file.get(user_file_id)
|
||||
if contents:
|
||||
combined_content = " ".join(contents)
|
||||
user_file_id_to_raw_text[str(user_file_id)] = combined_content
|
||||
token_count = (
|
||||
len(llm_tokenizer.encode(combined_content)) if llm_tokenizer else 0
|
||||
@@ -181,28 +174,16 @@ class UserFileIndexingAdapter:
|
||||
user_file_id_to_raw_text[str(user_file_id)] = ""
|
||||
user_file_id_to_token_count[str(user_file_id)] = None
|
||||
|
||||
access_aware_chunks = [
|
||||
DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=user_file_id_to_access.get(chunk.source_document.id, no_access),
|
||||
document_sets=set(),
|
||||
user_project=user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
personas=user_file_id_to_persona_ids.get(chunk.source_document.id, []),
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
)
|
||||
for chunk_num, chunk in enumerate(chunks_with_embeddings)
|
||||
]
|
||||
|
||||
return BuildMetadataAwareChunksResult(
|
||||
chunks=access_aware_chunks,
|
||||
return UserFileChunkEnricher(
|
||||
user_file_id_to_access=user_file_id_to_access,
|
||||
user_file_id_to_project_ids=user_file_id_to_project_ids,
|
||||
user_file_id_to_persona_ids=user_file_id_to_persona_ids,
|
||||
doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=user_file_id_to_new_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
|
||||
user_file_id_to_raw_text=user_file_id_to_raw_text,
|
||||
user_file_id_to_token_count=user_file_id_to_token_count,
|
||||
no_access=no_access,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
def _notify_assistant_owners_if_files_ready(
|
||||
@@ -246,8 +227,9 @@ class UserFileIndexingAdapter:
|
||||
context: DocumentBatchPrepareContext,
|
||||
updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002
|
||||
filtered_documents: list[Document], # noqa: ARG002
|
||||
result: BuildMetadataAwareChunksResult,
|
||||
enrichment: ChunkEnrichmentContext,
|
||||
) -> None:
|
||||
assert isinstance(enrichment, UserFileChunkEnricher)
|
||||
user_file_ids = [doc.id for doc in context.updatable_docs]
|
||||
|
||||
user_files = (
|
||||
@@ -263,8 +245,10 @@ class UserFileIndexingAdapter:
|
||||
user_file.last_project_sync_at = datetime.datetime.now(
|
||||
datetime.timezone.utc
|
||||
)
|
||||
user_file.chunk_count = result.doc_id_to_new_chunk_cnt[str(user_file.id)]
|
||||
user_file.token_count = result.user_file_id_to_token_count[
|
||||
user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get(
|
||||
str(user_file.id), 0
|
||||
)
|
||||
user_file.token_count = enrichment.user_file_id_to_token_count[
|
||||
str(user_file.id)
|
||||
]
|
||||
|
||||
@@ -276,8 +260,54 @@ class UserFileIndexingAdapter:
|
||||
# Store the plaintext in the file store for faster retrieval
|
||||
# NOTE: this creates its own session to avoid committing the overall
|
||||
# transaction.
|
||||
for user_file_id, raw_text in result.user_file_id_to_raw_text.items():
|
||||
for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items():
|
||||
store_user_file_plaintext(
|
||||
user_file_id=UUID(user_file_id),
|
||||
plaintext_content=raw_text,
|
||||
)
|
||||
|
||||
|
||||
class UserFileChunkEnricher:
|
||||
"""Pre-computed metadata for per-chunk enrichment of user-uploaded files."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_file_id_to_access: dict[str, DocumentAccess],
|
||||
user_file_id_to_project_ids: dict[str, list[int]],
|
||||
user_file_id_to_persona_ids: dict[str, list[int]],
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int],
|
||||
doc_id_to_new_chunk_cnt: dict[str, int],
|
||||
user_file_id_to_raw_text: dict[str, str],
|
||||
user_file_id_to_token_count: dict[str, int | None],
|
||||
no_access: DocumentAccess,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
self._user_file_id_to_access = user_file_id_to_access
|
||||
self._user_file_id_to_project_ids = user_file_id_to_project_ids
|
||||
self._user_file_id_to_persona_ids = user_file_id_to_persona_ids
|
||||
self._no_access = no_access
|
||||
self._tenant_id = tenant_id
|
||||
self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
|
||||
self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt
|
||||
self.user_file_id_to_raw_text = user_file_id_to_raw_text
|
||||
self.user_file_id_to_token_count = user_file_id_to_token_count
|
||||
|
||||
def enrich_chunk(
|
||||
self, chunk: IndexChunk, score: float
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=chunk,
|
||||
access=self._user_file_id_to_access.get(
|
||||
chunk.source_document.id, self._no_access
|
||||
),
|
||||
document_sets=set(),
|
||||
user_project=self._user_file_id_to_project_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
personas=self._user_file_id_to_persona_ids.get(
|
||||
chunk.source_document.id, []
|
||||
),
|
||||
boost=DEFAULT_BOOST,
|
||||
tenant_id=self._tenant_id,
|
||||
aggregated_chunk_boost_factor=score,
|
||||
)
|
||||
|
||||
88
backend/onyx/indexing/chunk_batch_store.py
Normal file
88
backend/onyx/indexing/chunk_batch_store.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import pickle
|
||||
import tempfile
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
class ChunkBatchStore:
|
||||
"""Manages serialization of embedded chunks to a temporary directory.
|
||||
|
||||
Owns the temp directory lifetime and provides save/load/stream/scrub
|
||||
operations.
|
||||
|
||||
Use as a context manager to ensure cleanup::
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
store.save(chunks, batch_idx=0)
|
||||
for chunk in store.stream():
|
||||
...
|
||||
"""
|
||||
|
||||
_EXT = ".pkl"
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._tmpdir_ctx = tempfile.TemporaryDirectory(prefix="onyx_embeddings_")
|
||||
self._tmpdir: Path | None = None
|
||||
|
||||
# -- context manager -----------------------------------------------------
|
||||
|
||||
def __enter__(self) -> "ChunkBatchStore":
|
||||
self._tmpdir = Path(self._tmpdir_ctx.__enter__())
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc: object) -> None:
|
||||
self._tmpdir_ctx.__exit__(*exc)
|
||||
self._tmpdir = None
|
||||
|
||||
@property
|
||||
def _dir(self) -> Path:
|
||||
assert self._tmpdir is not None, "ChunkBatchStore used outside context manager"
|
||||
return self._tmpdir
|
||||
|
||||
# -- storage primitives --------------------------------------------------
|
||||
|
||||
def save(self, chunks: list[IndexChunk], batch_idx: int) -> None:
|
||||
"""Serialize a batch of embedded chunks to disk."""
|
||||
with open(self._dir / f"batch_{batch_idx}{self._EXT}", "wb") as f:
|
||||
pickle.dump(chunks, f)
|
||||
|
||||
def _load(self, batch_file: Path) -> list[IndexChunk]:
|
||||
"""Deserialize a batch of embedded chunks from a file."""
|
||||
with open(batch_file, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def _batch_files(self) -> list[Path]:
|
||||
"""Return batch files sorted by numeric index."""
|
||||
return sorted(
|
||||
self._dir.glob(f"batch_*{self._EXT}"),
|
||||
key=lambda p: int(p.stem.removeprefix("batch_")),
|
||||
)
|
||||
|
||||
# -- higher-level operations ---------------------------------------------
|
||||
|
||||
def stream(self) -> Iterator[IndexChunk]:
|
||||
"""Yield all chunks across all batch files.
|
||||
|
||||
Each call returns a fresh generator, so the data can be iterated
|
||||
multiple times (e.g. once per document index).
|
||||
"""
|
||||
for batch_file in self._batch_files():
|
||||
yield from self._load(batch_file)
|
||||
|
||||
def scrub_failed_docs(self, failed_doc_ids: set[str]) -> None:
|
||||
"""Remove chunks belonging to *failed_doc_ids* from all batch files.
|
||||
|
||||
When a document fails embedding in batch N, earlier batches may
|
||||
already contain successfully embedded chunks for that document.
|
||||
This ensures the output is all-or-nothing per document.
|
||||
"""
|
||||
for batch_file in self._batch_files():
|
||||
batch_chunks = self._load(batch_file)
|
||||
cleaned = [
|
||||
c for c in batch_chunks if c.source_document.id not in failed_doc_ids
|
||||
]
|
||||
if len(cleaned) != len(batch_chunks):
|
||||
with open(batch_file, "wb") as f:
|
||||
pickle.dump(cleaned, f)
|
||||
@@ -1,5 +1,8 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -9,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
|
||||
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
|
||||
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
|
||||
@@ -43,10 +47,12 @@ from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.indexing.chunk_batch_store import ChunkBatchStore
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexingBatchAdapter
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
@@ -63,6 +69,7 @@ from onyx.natural_language_processing.utils import tokenizer_trim_middle
|
||||
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
|
||||
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
|
||||
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -91,6 +98,20 @@ class IndexingPipelineResult(BaseModel):
|
||||
|
||||
failures: list[ConnectorFailure]
|
||||
|
||||
@classmethod
|
||||
def empty(cls, total_docs: int) -> "IndexingPipelineResult":
|
||||
return cls(
|
||||
new_docs=0,
|
||||
total_docs=total_docs,
|
||||
total_chunks=0,
|
||||
failures=[],
|
||||
)
|
||||
|
||||
|
||||
class ChunkEmbeddingResult(BaseModel):
|
||||
successful_chunk_ids: list[tuple[int, str]] # (chunk_id, document_id)
|
||||
connector_failures: list[ConnectorFailure]
|
||||
|
||||
|
||||
class IndexingPipelineProtocol(Protocol):
|
||||
def __call__(
|
||||
@@ -139,6 +160,110 @@ def _upsert_documents_in_db(
|
||||
)
|
||||
|
||||
|
||||
def _get_failed_doc_ids(failures: list[ConnectorFailure]) -> set[str]:
|
||||
"""Extract document IDs from a list of connector failures."""
|
||||
return {f.failed_document.document_id for f in failures if f.failed_document}
|
||||
|
||||
|
||||
def _embed_chunks_to_store(
|
||||
chunks: list[DocAwareChunk],
|
||||
embedder: IndexingEmbedder,
|
||||
tenant_id: str,
|
||||
request_id: str | None,
|
||||
store: ChunkBatchStore,
|
||||
) -> ChunkEmbeddingResult:
|
||||
"""Embed chunks in batches, spilling each batch to *store*.
|
||||
|
||||
If a document fails embedding in any batch, its chunks are excluded from
|
||||
all batches (including earlier ones already written) so that the output
|
||||
is all-or-nothing per document.
|
||||
"""
|
||||
successful_chunk_ids: list[tuple[int, str]] = []
|
||||
all_embedding_failures: list[ConnectorFailure] = []
|
||||
# Track failed doc IDs across all batches so that a failure in batch N
|
||||
# causes chunks for that doc to be skipped in batch N+1 and stripped
|
||||
# from earlier batches.
|
||||
all_failed_doc_ids: set[str] = set()
|
||||
|
||||
for batch_idx, chunk_batch in enumerate(
|
||||
batch_generator(chunks, MAX_CHUNKS_PER_DOC_BATCH)
|
||||
):
|
||||
# Skip chunks belonging to documents that failed in earlier batches.
|
||||
chunk_batch = [
|
||||
c for c in chunk_batch if c.source_document.id not in all_failed_doc_ids
|
||||
]
|
||||
if not chunk_batch:
|
||||
continue
|
||||
|
||||
logger.debug(f"Embedding batch {batch_idx}: {len(chunk_batch)} chunks")
|
||||
|
||||
chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
|
||||
chunks=chunk_batch,
|
||||
embedder=embedder,
|
||||
tenant_id=tenant_id,
|
||||
request_id=request_id,
|
||||
)
|
||||
all_embedding_failures.extend(embedding_failures)
|
||||
all_failed_doc_ids.update(_get_failed_doc_ids(embedding_failures))
|
||||
|
||||
# Only keep successfully embedded chunks for non-failed docs.
|
||||
chunks_with_embeddings = [
|
||||
c
|
||||
for c in chunks_with_embeddings
|
||||
if c.source_document.id not in all_failed_doc_ids
|
||||
]
|
||||
|
||||
successful_chunk_ids.extend(
|
||||
(c.chunk_id, c.source_document.id) for c in chunks_with_embeddings
|
||||
)
|
||||
|
||||
store.save(chunks_with_embeddings, batch_idx)
|
||||
del chunks_with_embeddings
|
||||
|
||||
# Scrub earlier batches for docs that failed in later batches.
|
||||
if all_failed_doc_ids:
|
||||
store.scrub_failed_docs(all_failed_doc_ids)
|
||||
successful_chunk_ids = [
|
||||
(chunk_id, doc_id)
|
||||
for chunk_id, doc_id in successful_chunk_ids
|
||||
if doc_id not in all_failed_doc_ids
|
||||
]
|
||||
|
||||
return ChunkEmbeddingResult(
|
||||
successful_chunk_ids=successful_chunk_ids,
|
||||
connector_failures=all_embedding_failures,
|
||||
)
|
||||
|
||||
|
||||
@contextmanager
|
||||
def embed_and_stream(
|
||||
chunks: list[DocAwareChunk],
|
||||
embedder: IndexingEmbedder,
|
||||
tenant_id: str,
|
||||
request_id: str | None,
|
||||
) -> Generator[tuple[ChunkEmbeddingResult, ChunkBatchStore], None, None]:
|
||||
"""Embed chunks to disk and yield a ``(result, store)`` pair.
|
||||
|
||||
The store owns the temp directory — files are cleaned up when the context
|
||||
manager exits.
|
||||
|
||||
Usage::
|
||||
|
||||
with embed_and_stream(chunks, embedder, tenant_id, req_id) as (result, store):
|
||||
for chunk in store.stream():
|
||||
...
|
||||
"""
|
||||
with ChunkBatchStore() as store:
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=embedder,
|
||||
tenant_id=tenant_id,
|
||||
request_id=request_id,
|
||||
store=store,
|
||||
)
|
||||
yield result, store
|
||||
|
||||
|
||||
def get_doc_ids_to_update(
|
||||
documents: list[Document], db_docs: list[DBDocument]
|
||||
) -> list[Document]:
|
||||
@@ -637,6 +762,29 @@ def add_contextual_summaries(
|
||||
return chunks
|
||||
|
||||
|
||||
def _verify_indexing_completeness(
|
||||
insertion_records: list[DocumentInsertionRecord],
|
||||
write_failures: list[ConnectorFailure],
|
||||
embedding_failed_doc_ids: set[str],
|
||||
updatable_ids: list[str],
|
||||
document_index_name: str,
|
||||
) -> None:
|
||||
"""Verify that every updatable document was either indexed or reported as failed."""
|
||||
all_returned_doc_ids = (
|
||||
{r.document_id for r in insertion_records}
|
||||
| {f.failed_document.document_id for f in write_failures if f.failed_document}
|
||||
| embedding_failed_doc_ids
|
||||
)
|
||||
if all_returned_doc_ids != set(updatable_ids):
|
||||
raise RuntimeError(
|
||||
f"Some documents were not successfully indexed. "
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Returned IDs: {all_returned_doc_ids}. "
|
||||
f"This should never happen. "
|
||||
f"This occured for document index {document_index_name}"
|
||||
)
|
||||
|
||||
|
||||
@log_function_time(debug_only=True)
|
||||
def index_doc_batch(
|
||||
*,
|
||||
@@ -672,12 +820,7 @@ def index_doc_batch(
|
||||
filtered_documents = filter_fnc(document_batch)
|
||||
context = adapter.prepare(filtered_documents, ignore_time_skip)
|
||||
if not context:
|
||||
return IndexingPipelineResult(
|
||||
new_docs=0,
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=0,
|
||||
failures=[],
|
||||
)
|
||||
return IndexingPipelineResult.empty(len(filtered_documents))
|
||||
|
||||
# Convert documents to IndexingDocument objects with processed section
|
||||
# logger.debug("Processing image sections")
|
||||
@@ -716,117 +859,99 @@ def index_doc_batch(
|
||||
)
|
||||
|
||||
logger.debug("Starting embedding")
|
||||
chunks_with_embeddings, embedding_failures = (
|
||||
embed_chunks_with_failure_handling(
|
||||
chunks=chunks,
|
||||
embedder=embedder,
|
||||
tenant_id=tenant_id,
|
||||
request_id=request_id,
|
||||
)
|
||||
if chunks
|
||||
else ([], [])
|
||||
)
|
||||
with embed_and_stream(chunks, embedder, tenant_id, request_id) as (
|
||||
embedding_result,
|
||||
chunk_store,
|
||||
):
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
updatable_chunk_data = [
|
||||
UpdatableChunkData(
|
||||
chunk_id=chunk_id,
|
||||
document_id=document_id,
|
||||
boost_score=1.0,
|
||||
)
|
||||
for chunk_id, document_id in embedding_result.successful_chunk_ids
|
||||
]
|
||||
|
||||
chunk_content_scores = [1.0] * len(chunks_with_embeddings)
|
||||
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
updatable_chunk_data = [
|
||||
UpdatableChunkData(
|
||||
chunk_id=chunk.chunk_id,
|
||||
document_id=chunk.source_document.id,
|
||||
boost_score=score,
|
||||
)
|
||||
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
|
||||
]
|
||||
|
||||
# Acquires a lock on the documents so that no other process can modify them
|
||||
# NOTE: don't need to acquire till here, since this is when the actual race condition
|
||||
# with Vespa can occur.
|
||||
with adapter.lock_context(context.updatable_docs):
|
||||
# we're concerned about race conditions where multiple simultaneous indexings might result
|
||||
# in one set of metadata overwriting another one in vespa.
|
||||
# we still write data here for the immediate and most likely correct sync, but
|
||||
# to resolve this, an update of the last modified field at the end of this loop
|
||||
# always triggers a final metadata sync via the celery queue
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=chunks_with_embeddings,
|
||||
chunk_content_scores=chunk_content_scores,
|
||||
tenant_id=tenant_id,
|
||||
context=context,
|
||||
embedding_failed_doc_ids = _get_failed_doc_ids(
|
||||
embedding_result.connector_failures
|
||||
)
|
||||
|
||||
short_descriptor_list = [chunk.to_short_descriptor() for chunk in result.chunks]
|
||||
short_descriptor_log = str(short_descriptor_list)[:1024]
|
||||
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
|
||||
# Filter to only successfully embedded chunks so
|
||||
# doc_id_to_new_chunk_cnt reflects what's actually written to Vespa.
|
||||
embedded_chunks = [
|
||||
c for c in chunks if c.source_document.id not in embedding_failed_doc_ids
|
||||
]
|
||||
|
||||
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
|
||||
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
|
||||
for document_index in document_indices:
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
# in this set
|
||||
(
|
||||
insertion_records,
|
||||
vector_db_write_failures,
|
||||
) = write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
chunks=result.chunks,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=result.doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=result.doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
),
|
||||
# Acquires a lock on the documents so that no other process can modify
|
||||
# them. Not needed until here, since this is when the actual race
|
||||
# condition with vector db can occur.
|
||||
with adapter.lock_context(context.updatable_docs):
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=tenant_id,
|
||||
chunks=embedded_chunks,
|
||||
)
|
||||
|
||||
all_returned_doc_ids: set[str] = (
|
||||
{record.document_id for record in insertion_records}
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in vector_db_write_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in embedding_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
index_batch_params = IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
)
|
||||
if all_returned_doc_ids != set(updatable_ids):
|
||||
raise RuntimeError(
|
||||
f"Some documents were not successfully indexed. "
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Returned IDs: {all_returned_doc_ids}. "
|
||||
"This should never happen."
|
||||
f"This occured for document index {document_index.__class__.__name__}"
|
||||
)
|
||||
# We treat the first document index we got as the primary one used
|
||||
# for reporting the state of indexing.
|
||||
if primary_doc_idx_insertion_records is None:
|
||||
primary_doc_idx_insertion_records = insertion_records
|
||||
if primary_doc_idx_vector_db_write_failures is None:
|
||||
primary_doc_idx_vector_db_write_failures = vector_db_write_failures
|
||||
|
||||
adapter.post_index(
|
||||
context=context,
|
||||
updatable_chunk_data=updatable_chunk_data,
|
||||
filtered_documents=filtered_documents,
|
||||
result=result,
|
||||
)
|
||||
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = (
|
||||
None
|
||||
)
|
||||
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = (
|
||||
None
|
||||
)
|
||||
|
||||
for document_index in document_indices:
|
||||
|
||||
def _enriched_stream() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for chunk in chunk_store.stream():
|
||||
yield enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
insertion_records, write_failures = (
|
||||
write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
make_chunks=_enriched_stream,
|
||||
index_batch_params=index_batch_params,
|
||||
)
|
||||
)
|
||||
|
||||
_verify_indexing_completeness(
|
||||
insertion_records=insertion_records,
|
||||
write_failures=write_failures,
|
||||
embedding_failed_doc_ids=embedding_failed_doc_ids,
|
||||
updatable_ids=updatable_ids,
|
||||
document_index_name=document_index.__class__.__name__,
|
||||
)
|
||||
# We treat the first document index we got as the primary one used
|
||||
# for reporting the state of indexing.
|
||||
if primary_doc_idx_insertion_records is None:
|
||||
primary_doc_idx_insertion_records = insertion_records
|
||||
if primary_doc_idx_vector_db_write_failures is None:
|
||||
primary_doc_idx_vector_db_write_failures = write_failures
|
||||
|
||||
adapter.post_index(
|
||||
context=context,
|
||||
updatable_chunk_data=updatable_chunk_data,
|
||||
filtered_documents=filtered_documents,
|
||||
enrichment=enricher,
|
||||
)
|
||||
|
||||
assert primary_doc_idx_insertion_records is not None
|
||||
assert primary_doc_idx_vector_db_write_failures is not None
|
||||
return IndexingPipelineResult(
|
||||
new_docs=len(
|
||||
[r for r in primary_doc_idx_insertion_records if not r.already_existed]
|
||||
new_docs=sum(
|
||||
1 for r in primary_doc_idx_insertion_records if not r.already_existed
|
||||
),
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=len(chunks_with_embeddings),
|
||||
failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
|
||||
total_chunks=len(embedding_result.successful_chunk_ids),
|
||||
failures=primary_doc_idx_vector_db_write_failures
|
||||
+ embedding_result.connector_failures,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -235,12 +235,16 @@ class UpdatableChunkData(BaseModel):
|
||||
boost_score: float
|
||||
|
||||
|
||||
class BuildMetadataAwareChunksResult(BaseModel):
|
||||
chunks: list[DocMetadataAwareIndexChunk]
|
||||
class ChunkEnrichmentContext(Protocol):
|
||||
"""Returned by prepare_enrichment. Holds pre-computed metadata lookups
|
||||
and provides per-chunk enrichment."""
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int]
|
||||
doc_id_to_new_chunk_cnt: dict[str, int]
|
||||
user_file_id_to_raw_text: dict[str, str]
|
||||
user_file_id_to_token_count: dict[str, int | None]
|
||||
|
||||
def enrich_chunk(
|
||||
self, chunk: IndexChunk, score: float
|
||||
) -> DocMetadataAwareIndexChunk: ...
|
||||
|
||||
|
||||
class IndexingBatchAdapter(Protocol):
|
||||
@@ -254,18 +258,24 @@ class IndexingBatchAdapter(Protocol):
|
||||
) -> Generator[TransactionalContext, None, None]:
|
||||
"""Provide a transaction/row-lock context for critical updates."""
|
||||
|
||||
def build_metadata_aware_chunks(
|
||||
def prepare_enrichment(
|
||||
self,
|
||||
chunks_with_embeddings: list[IndexChunk],
|
||||
chunk_content_scores: list[float],
|
||||
tenant_id: str,
|
||||
context: "DocumentBatchPrepareContext",
|
||||
) -> BuildMetadataAwareChunksResult: ...
|
||||
tenant_id: str,
|
||||
chunks: list[DocAwareChunk],
|
||||
) -> ChunkEnrichmentContext:
|
||||
"""Prepare per-chunk enrichment data (access, document sets, boost, etc.).
|
||||
|
||||
Precondition: ``chunks`` have already been through the embedding step
|
||||
(i.e. they are ``IndexChunk`` instances with populated embeddings,
|
||||
passed here as the base ``DocAwareChunk`` type).
|
||||
"""
|
||||
...
|
||||
|
||||
def post_index(
|
||||
self,
|
||||
context: "DocumentBatchPrepareContext",
|
||||
updatable_chunk_data: list[UpdatableChunkData],
|
||||
filtered_documents: list[Document],
|
||||
result: BuildMetadataAwareChunksResult,
|
||||
enrichment: ChunkEnrichmentContext,
|
||||
) -> None: ...
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from http import HTTPStatus
|
||||
from itertools import chain
|
||||
from itertools import groupby
|
||||
|
||||
import httpx
|
||||
|
||||
@@ -28,22 +31,22 @@ def _log_insufficient_storage_error(e: Exception) -> None:
|
||||
|
||||
def write_chunks_to_vector_db_with_backoff(
|
||||
document_index: DocumentIndex,
|
||||
chunks: list[DocMetadataAwareIndexChunk],
|
||||
make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]],
|
||||
index_batch_params: IndexBatchParams,
|
||||
) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]:
|
||||
"""Tries to insert all chunks in one large batch. If that batch fails for any reason,
|
||||
goes document by document to isolate the failure(s).
|
||||
|
||||
IMPORTANT: must pass in whole documents at a time not individual chunks, since the
|
||||
vector DB interface assumes that all chunks for a single document are present.
|
||||
vector DB interface assumes that all chunks for a single document are present. The
|
||||
chunks must also be in contiguous batches
|
||||
"""
|
||||
|
||||
# first try to write the chunks to the vector db
|
||||
try:
|
||||
return (
|
||||
list(
|
||||
document_index.index(
|
||||
chunks=chunks,
|
||||
chunks=make_chunks(),
|
||||
index_batch_params=index_batch_params,
|
||||
)
|
||||
),
|
||||
@@ -60,14 +63,16 @@ def write_chunks_to_vector_db_with_backoff(
|
||||
# wait a couple seconds just to give the vector db a chance to recover
|
||||
time.sleep(2)
|
||||
|
||||
# try writing each doc one by one
|
||||
chunks_for_docs: dict[str, list[DocMetadataAwareIndexChunk]] = defaultdict(list)
|
||||
for chunk in chunks:
|
||||
chunks_for_docs[chunk.source_document.id].append(chunk)
|
||||
|
||||
insertion_records: list[DocumentInsertionRecord] = []
|
||||
failures: list[ConnectorFailure] = []
|
||||
for doc_id, chunks_for_doc in chunks_for_docs.items():
|
||||
|
||||
def key(chunk: DocMetadataAwareIndexChunk) -> str:
|
||||
return chunk.source_document.id
|
||||
|
||||
for doc_id, chunks_for_doc in groupby(make_chunks(), key=key):
|
||||
first_chunk = next(chunks_for_doc)
|
||||
chunks_for_doc = chain([first_chunk], chunks_for_doc)
|
||||
|
||||
try:
|
||||
insertion_records.extend(
|
||||
document_index.index(
|
||||
@@ -87,9 +92,7 @@ def write_chunks_to_vector_db_with_backoff(
|
||||
ConnectorFailure(
|
||||
failed_document=DocumentFailure(
|
||||
document_id=doc_id,
|
||||
document_link=(
|
||||
chunks_for_doc[0].get_link() if chunks_for_doc else None
|
||||
),
|
||||
document_link=first_chunk.get_link(),
|
||||
),
|
||||
failure_message=str(e),
|
||||
exception=e,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
External dependency unit tests for UserFileIndexingAdapter metadata writing.
|
||||
|
||||
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
|
||||
Validates that prepare_enrichment produces DocMetadataAwareIndexChunk
|
||||
objects with both `user_project` and `personas` fields populated correctly
|
||||
based on actual DB associations.
|
||||
|
||||
@@ -127,7 +127,7 @@ def _make_index_chunk(user_file: UserFile) -> IndexChunk:
|
||||
|
||||
|
||||
class TestAdapterWritesBothMetadataFields:
|
||||
"""build_metadata_aware_chunks must populate user_project AND personas."""
|
||||
"""prepare_enrichment must populate user_project AND personas."""
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
@@ -153,15 +153,13 @@ class TestAdapterWritesBothMetadataFields:
|
||||
doc = chunk.source_document
|
||||
context = DocumentBatchPrepareContext(updatable_docs=[doc], id_to_boost_map={})
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@@ -190,15 +188,13 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
assert len(result.chunks) == 1
|
||||
aware_chunk = result.chunks[0]
|
||||
assert project.id in aware_chunk.user_project
|
||||
assert aware_chunk.personas == []
|
||||
|
||||
@@ -229,14 +225,13 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert persona.id in aware_chunk.personas
|
||||
assert project.id in aware_chunk.user_project
|
||||
|
||||
@@ -261,14 +256,13 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert aware_chunk.personas == []
|
||||
assert aware_chunk.user_project == []
|
||||
|
||||
@@ -300,12 +294,11 @@ class TestAdapterWritesBothMetadataFields:
|
||||
updatable_docs=[chunk.source_document], id_to_boost_map={}
|
||||
)
|
||||
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=TEST_TENANT_ID,
|
||||
chunks=[chunk],
|
||||
)
|
||||
aware_chunk = enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
aware_chunk = result.chunks[0]
|
||||
assert set(aware_chunk.personas) == {persona_a.id, persona_b.id}
|
||||
|
||||
@@ -6,6 +6,7 @@ These tests assume Vespa and OpenSearch are running.
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
@@ -21,6 +22,7 @@ from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
)
|
||||
from onyx.document_index.vespa.index import VespaIndex
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
from tests.external_dependency_unit.document_index.conftest import EMBEDDING_DIM
|
||||
from tests.external_dependency_unit.document_index.conftest import make_chunk
|
||||
@@ -201,3 +203,25 @@ class TestDocumentIndexNew:
|
||||
assert len(result_map) == 2
|
||||
assert result_map[existing_doc] is True
|
||||
assert result_map[new_doc] is False
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndexNew],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
doc_id = f"test_gen_{uuid.uuid4().hex[:8]}"
|
||||
metadata = make_indexing_metadata([doc_id], old_counts=[0], new_counts=[3])
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk(doc_id, chunk_id=i)
|
||||
|
||||
results = document_index.index(
|
||||
chunks=chunk_gen(), indexing_metadata=metadata
|
||||
)
|
||||
|
||||
assert len(results) == 1
|
||||
assert results[0].document_id == doc_id
|
||||
assert results[0].already_existed is False
|
||||
|
||||
@@ -5,6 +5,7 @@ These tests assume Vespa and OpenSearch are running.
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -166,3 +167,29 @@ class TestDocumentIndexOld:
|
||||
batch_retrieval=True,
|
||||
)
|
||||
assert len(inference_chunks) == 0
|
||||
|
||||
def test_index_accepts_generator(
|
||||
self,
|
||||
document_indices: list[DocumentIndex],
|
||||
tenant_context: None, # noqa: ARG002
|
||||
) -> None:
|
||||
"""index() accepts a generator (any iterable), not just a list."""
|
||||
for document_index in document_indices:
|
||||
|
||||
def chunk_gen() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for i in range(3):
|
||||
yield make_chunk("test_doc_gen", chunk_id=i)
|
||||
|
||||
index_batch_params = IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt={"test_doc_gen": 0},
|
||||
doc_id_to_new_chunk_cnt={"test_doc_gen": 3},
|
||||
tenant_id=get_current_tenant_id(),
|
||||
large_chunks_enabled=False,
|
||||
)
|
||||
|
||||
results = document_index.index(chunk_gen(), index_batch_params)
|
||||
|
||||
assert len(results) == 1
|
||||
record = results.pop()
|
||||
assert record.document_id == "test_doc_gen"
|
||||
assert record.already_existed is False
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.opensearch_document_index import (
|
||||
OpenSearchDocumentIndex,
|
||||
)
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int,
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
"""Creates a minimal DocMetadataAwareIndexChunk for testing."""
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
sections=[TextSection(text="test", link="http://test.com")],
|
||||
source=DocumentSource.FILE,
|
||||
semantic_identifier="test_doc",
|
||||
metadata={},
|
||||
)
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
return DocMetadataAwareIndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb="test",
|
||||
content="test content",
|
||||
source_links={0: "http://test.com"},
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
embeddings={"full_embedding": [0.1] * 10, "mini_chunk_embeddings": []},
|
||||
title_embedding=[0.1] * 10,
|
||||
tenant_id="test_tenant",
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
ancestor_hierarchy_node_ids=[],
|
||||
)
|
||||
|
||||
|
||||
def _make_index() -> OpenSearchDocumentIndex:
|
||||
"""Creates an OpenSearchDocumentIndex with a mocked client."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.bulk_index_documents = MagicMock()
|
||||
|
||||
tenant_state = TenantState(tenant_id="test_tenant", multitenant=False)
|
||||
|
||||
index = OpenSearchDocumentIndex.__new__(OpenSearchDocumentIndex)
|
||||
index._index_name = "test_index"
|
||||
index._client = mock_client
|
||||
index._tenant_state = tenant_state
|
||||
|
||||
return index
|
||||
|
||||
|
||||
def _make_metadata(doc_id: str, chunk_count: int) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=0,
|
||||
new_chunk_cnt=chunk_count,
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_under_batch_limit_flushes_once() -> None:
|
||||
"""A document with fewer chunks than MAX_CHUNKS_PER_DOC_BATCH should flush once."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 50
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
assert index._client.bulk_index_documents.call_count == 1
|
||||
batch_arg = index._client.bulk_index_documents.call_args_list[0]
|
||||
assert len(batch_arg.kwargs["documents"]) == num_chunks
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_over_batch_limit_flushes_multiple_times() -> None:
|
||||
"""A document with more chunks than MAX_CHUNKS_PER_DOC_BATCH should flush multiple times."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 250
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 250 chunks / 100 per batch = 3 flushes (100 + 100 + 50)
|
||||
assert index._client.bulk_index_documents.call_count == 3
|
||||
batch_sizes = [
|
||||
len(call.kwargs["documents"])
|
||||
for call in index._client.bulk_index_documents.call_args_list
|
||||
]
|
||||
assert batch_sizes == [100, 100, 50]
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_exactly_at_batch_limit() -> None:
|
||||
"""A document with exactly MAX_CHUNKS_PER_DOC_BATCH chunks should flush once
|
||||
(the flush happens on the next chunk, not at the boundary)."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 100
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 100 chunks hit the >= check on chunk 101 which doesn't exist,
|
||||
# so final flush handles all 100
|
||||
# Actually: the elif fires when len(current_chunks) >= 100, which happens
|
||||
# when current_chunks has 100 items and the 101st chunk arrives.
|
||||
# With exactly 100 chunks, the 100th chunk makes len == 99, then appended -> 100.
|
||||
# No 101st chunk arrives, so the final flush handles all 100.
|
||||
assert index._client.bulk_index_documents.call_count == 1
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_single_doc_one_over_batch_limit() -> None:
|
||||
"""101 chunks for one doc: first 100 flushed when the 101st arrives, then
|
||||
the 101st is flushed at the end."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 101
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
assert index._client.bulk_index_documents.call_count == 2
|
||||
batch_sizes = [
|
||||
len(call.kwargs["documents"])
|
||||
for call in index._client.bulk_index_documents.call_args_list
|
||||
]
|
||||
assert batch_sizes == [100, 1]
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_multiple_docs_each_under_limit_flush_per_doc() -> None:
|
||||
"""Multiple documents each under the batch limit should flush once per document."""
|
||||
index = _make_index()
|
||||
chunks = []
|
||||
for doc_idx in range(3):
|
||||
doc_id = f"doc_{doc_idx}"
|
||||
for chunk_idx in range(50):
|
||||
chunks.append(_make_chunk(doc_id, chunk_idx))
|
||||
|
||||
metadata = IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
f"doc_{i}": IndexingMetadata.ChunkCounts(old_chunk_cnt=0, new_chunk_cnt=50)
|
||||
for i in range(3)
|
||||
},
|
||||
)
|
||||
|
||||
with patch.object(index, "delete", return_value=0):
|
||||
index.index(chunks, metadata)
|
||||
|
||||
# 3 documents = 3 flushes (one per doc boundary + final)
|
||||
assert index._client.bulk_index_documents.call_count == 3
|
||||
|
||||
|
||||
@patch(
|
||||
"onyx.document_index.opensearch.opensearch_document_index.MAX_CHUNKS_PER_DOC_BATCH",
|
||||
100,
|
||||
)
|
||||
def test_delete_called_once_per_document() -> None:
|
||||
"""Even with multiple flushes for a single document, delete should only be
|
||||
called once per document."""
|
||||
index = _make_index()
|
||||
doc_id = "doc_1"
|
||||
num_chunks = 250
|
||||
chunks = [_make_chunk(doc_id, i) for i in range(num_chunks)]
|
||||
metadata = _make_metadata(doc_id, num_chunks)
|
||||
|
||||
with patch.object(index, "delete", return_value=0) as mock_delete:
|
||||
index.index(chunks, metadata)
|
||||
|
||||
mock_delete.assert_called_once_with(doc_id, None)
|
||||
@@ -0,0 +1,152 @@
|
||||
"""Unit tests for VespaDocumentIndex.index().
|
||||
|
||||
These tests mock all external I/O (HTTP calls, thread pools) and verify
|
||||
the streaming logic, ID cleaning/mapping, and DocumentInsertionRecord
|
||||
construction.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
|
||||
from onyx.document_index.interfaces_new import IndexingMetadata
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_chunk(
|
||||
doc_id: str,
|
||||
chunk_id: int = 0,
|
||||
content: str = "test content",
|
||||
) -> DocMetadataAwareIndexChunk:
|
||||
doc = Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test_doc",
|
||||
sections=[TextSection(text=content, link=None)],
|
||||
source=DocumentSource.NOT_APPLICABLE,
|
||||
metadata={},
|
||||
)
|
||||
index_chunk = IndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb=content[:50],
|
||||
content=content,
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=doc,
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.1] * 10,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
access = DocumentAccess.build(
|
||||
user_emails=[],
|
||||
user_groups=[],
|
||||
external_user_emails=[],
|
||||
external_user_group_ids=[],
|
||||
is_public=True,
|
||||
)
|
||||
return DocMetadataAwareIndexChunk.from_index_chunk(
|
||||
index_chunk=index_chunk,
|
||||
access=access,
|
||||
document_sets=set(),
|
||||
user_project=[],
|
||||
personas=[],
|
||||
boost=0,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
tenant_id="test_tenant",
|
||||
)
|
||||
|
||||
|
||||
def _make_indexing_metadata(
|
||||
doc_ids: list[str],
|
||||
old_counts: list[int],
|
||||
new_counts: list[int],
|
||||
) -> IndexingMetadata:
|
||||
return IndexingMetadata(
|
||||
doc_id_to_chunk_cnt_diff={
|
||||
doc_id: IndexingMetadata.ChunkCounts(
|
||||
old_chunk_cnt=old,
|
||||
new_chunk_cnt=new,
|
||||
)
|
||||
for doc_id, old, new in zip(doc_ids, old_counts, new_counts)
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _stub_enrich(
|
||||
doc_id: str,
|
||||
old_chunk_cnt: int,
|
||||
) -> EnrichedDocumentIndexingInfo:
|
||||
"""Build an EnrichedDocumentIndexingInfo that says 'no chunks to delete'
|
||||
when old_chunk_cnt == 0, or 'has existing chunks' otherwise."""
|
||||
return EnrichedDocumentIndexingInfo(
|
||||
doc_id=doc_id,
|
||||
chunk_start_index=0,
|
||||
old_version=False,
|
||||
chunk_end_index=old_chunk_cnt,
|
||||
)
|
||||
|
||||
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.batch_index_vespa_chunks")
|
||||
@patch("onyx.document_index.vespa.vespa_document_index.delete_vespa_chunks")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.get_document_chunk_ids",
|
||||
return_value=[],
|
||||
)
|
||||
@patch("onyx.document_index.vespa.vespa_document_index._enrich_basic_chunk_info")
|
||||
@patch(
|
||||
"onyx.document_index.vespa.vespa_document_index.BATCH_SIZE",
|
||||
3,
|
||||
)
|
||||
def test_index_respects_batch_size(
|
||||
mock_enrich: MagicMock,
|
||||
mock_get_chunk_ids: MagicMock, # noqa: ARG001
|
||||
mock_delete: MagicMock, # noqa: ARG001
|
||||
mock_batch_index: MagicMock,
|
||||
) -> None:
|
||||
"""When chunks exceed BATCH_SIZE, batch_index_vespa_chunks is called
|
||||
multiple times with correctly sized batches."""
|
||||
mock_enrich.return_value = _stub_enrich("doc1", old_chunk_cnt=0)
|
||||
|
||||
index = VespaDocumentIndex(
|
||||
index_name="test_index",
|
||||
tenant_state=TenantState(tenant_id="test_tenant", multitenant=False),
|
||||
large_chunks_enabled=False,
|
||||
httpx_client=MagicMock(),
|
||||
)
|
||||
|
||||
chunks = [_make_chunk("doc1", chunk_id=i) for i in range(7)]
|
||||
metadata = _make_indexing_metadata(["doc1"], old_counts=[0], new_counts=[7])
|
||||
|
||||
results = index.index(chunks=chunks, indexing_metadata=metadata)
|
||||
|
||||
assert len(results) == 1
|
||||
|
||||
# With BATCH_SIZE=3 and 7 chunks: batches of 3, 3, 1
|
||||
assert mock_batch_index.call_count == 3
|
||||
batch_sizes = [len(c.kwargs["chunks"]) for c in mock_batch_index.call_args_list]
|
||||
assert batch_sizes == [3, 3, 1]
|
||||
|
||||
# Verify all chunks are accounted for and in order
|
||||
all_indexed = [
|
||||
chunk for c in mock_batch_index.call_args_list for chunk in c.kwargs["chunks"]
|
||||
]
|
||||
assert len(all_indexed) == 7
|
||||
assert [c.chunk_id for c in all_indexed] == list(range(7))
|
||||
390
backend/tests/unit/onyx/indexing/test_embed_chunks_in_batches.py
Normal file
390
backend/tests/unit/onyx/indexing/test_embed_chunks_in_batches.py
Normal file
@@ -0,0 +1,390 @@
|
||||
"""Unit tests for _embed_chunks_to_store.
|
||||
|
||||
Tests cover:
|
||||
- Single batch, no failures
|
||||
- Multiple batches, no failures
|
||||
- Failure in a single batch
|
||||
- Cross-batch document failure scrubbing
|
||||
- Later batches skip already-failed docs
|
||||
- Empty input
|
||||
- All chunks fail
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.chunk_batch_store import ChunkBatchStore
|
||||
from onyx.indexing.indexing_pipeline import _embed_chunks_to_store
|
||||
from onyx.indexing.models import ChunkEmbedding
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_doc(doc_id: str) -> Document:
|
||||
return Document(
|
||||
id=doc_id,
|
||||
semantic_identifier="test",
|
||||
source=DocumentSource.FILE,
|
||||
sections=[TextSection(text="test", link=None)],
|
||||
metadata={},
|
||||
)
|
||||
|
||||
|
||||
def _make_chunk(doc_id: str, chunk_id: int) -> DocAwareChunk:
|
||||
return DocAwareChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb="test",
|
||||
content="test content",
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=_make_doc(doc_id),
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
)
|
||||
|
||||
|
||||
def _make_index_chunk(doc_id: str, chunk_id: int) -> IndexChunk:
|
||||
"""Create an IndexChunk (a DocAwareChunk with embeddings)."""
|
||||
return IndexChunk(
|
||||
chunk_id=chunk_id,
|
||||
blurb="test",
|
||||
content="test content",
|
||||
source_links=None,
|
||||
image_file_id=None,
|
||||
section_continuation=False,
|
||||
source_document=_make_doc(doc_id),
|
||||
title_prefix="",
|
||||
metadata_suffix_semantic="",
|
||||
metadata_suffix_keyword="",
|
||||
mini_chunk_texts=None,
|
||||
large_chunk_id=None,
|
||||
doc_summary="",
|
||||
chunk_context="",
|
||||
contextual_rag_reserved_tokens=0,
|
||||
embeddings=ChunkEmbedding(
|
||||
full_embedding=[0.1] * 10,
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_failure(doc_id: str) -> ConnectorFailure:
|
||||
return ConnectorFailure(
|
||||
failed_document=DocumentFailure(document_id=doc_id, document_link=None),
|
||||
failure_message="embedding failed",
|
||||
exception=RuntimeError("embedding failed"),
|
||||
)
|
||||
|
||||
|
||||
def _mock_embed_success(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
"""Simulate successful embedding of all chunks."""
|
||||
return (
|
||||
[_make_index_chunk(c.source_document.id, c.chunk_id) for c in chunks],
|
||||
[],
|
||||
)
|
||||
|
||||
|
||||
def _mock_embed_fail_doc(
|
||||
fail_doc_id: str,
|
||||
) -> "callable":
|
||||
"""Return an embed mock that fails all chunks for a specific doc."""
|
||||
|
||||
def _embed(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
successes = [
|
||||
_make_index_chunk(c.source_document.id, c.chunk_id)
|
||||
for c in chunks
|
||||
if c.source_document.id != fail_doc_id
|
||||
]
|
||||
failures = (
|
||||
[_make_failure(fail_doc_id)]
|
||||
if any(c.source_document.id == fail_doc_id for c in chunks)
|
||||
else []
|
||||
)
|
||||
return successes, failures
|
||||
|
||||
return _embed
|
||||
|
||||
|
||||
class TestEmbedChunksInBatches:
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
|
||||
def test_single_batch_no_failures(self, mock_embed: MagicMock) -> None:
|
||||
"""All chunks fit in one batch and embed successfully."""
|
||||
mock_embed.side_effect = _mock_embed_success
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [_make_chunk("doc1", i) for i in range(3)]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert len(result.successful_chunk_ids) == 3
|
||||
assert len(result.connector_failures) == 0
|
||||
|
||||
# Verify stored contents
|
||||
assert len(store._batch_files()) == 1
|
||||
stored = list(store.stream())
|
||||
assert len(stored) == 3
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
|
||||
def test_multiple_batches_no_failures(self, mock_embed: MagicMock) -> None:
|
||||
"""Chunks are split across multiple batches, all succeed."""
|
||||
mock_embed.side_effect = _mock_embed_success
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [_make_chunk("doc1", i) for i in range(7)]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert len(result.successful_chunk_ids) == 7
|
||||
assert len(result.connector_failures) == 0
|
||||
assert len(store._batch_files()) == 3 # 3 + 3 + 1
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
|
||||
def test_single_batch_with_failure(self, mock_embed: MagicMock) -> None:
|
||||
"""One doc fails embedding, its chunks are excluded from results."""
|
||||
mock_embed.side_effect = _mock_embed_fail_doc("doc2")
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [
|
||||
_make_chunk("doc1", 0),
|
||||
_make_chunk("doc2", 1),
|
||||
_make_chunk("doc1", 2),
|
||||
]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert len(result.connector_failures) == 1
|
||||
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
|
||||
assert "doc2" not in successful_doc_ids
|
||||
assert "doc1" in successful_doc_ids
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
|
||||
def test_cross_batch_failure_scrubs_earlier_batch(
|
||||
self, mock_embed: MagicMock
|
||||
) -> None:
|
||||
"""Doc A spans batches 0 and 1. It succeeds in batch 0 but fails in
|
||||
batch 1. Its chunks should be scrubbed from batch 0's batch file."""
|
||||
call_count = 0
|
||||
|
||||
def _embed(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
return _mock_embed_success(chunks)
|
||||
else:
|
||||
return _mock_embed_fail_doc("docA")(chunks)
|
||||
|
||||
mock_embed.side_effect = _embed
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [
|
||||
_make_chunk("docA", 0),
|
||||
_make_chunk("docA", 1),
|
||||
_make_chunk("docA", 2),
|
||||
_make_chunk("docA", 3),
|
||||
_make_chunk("docB", 0),
|
||||
_make_chunk("docB", 1),
|
||||
]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
# docA should be fully excluded from results
|
||||
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
|
||||
assert "docA" not in successful_doc_ids
|
||||
assert "docB" in successful_doc_ids
|
||||
assert len(result.connector_failures) == 1
|
||||
|
||||
# Verify batch 0 was scrubbed of docA chunks
|
||||
all_stored = list(store.stream())
|
||||
stored_doc_ids = {c.source_document.id for c in all_stored}
|
||||
assert "docA" not in stored_doc_ids
|
||||
assert "docB" in stored_doc_ids
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
|
||||
def test_later_batch_skips_already_failed_doc(self, mock_embed: MagicMock) -> None:
|
||||
"""If docA fails in batch 0, its chunks in batch 1 are skipped
|
||||
entirely (never sent to the embedder)."""
|
||||
embedded_doc_ids: list[str] = []
|
||||
|
||||
def _embed(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
for c in chunks:
|
||||
embedded_doc_ids.append(c.source_document.id)
|
||||
return _mock_embed_fail_doc("docA")(chunks)
|
||||
|
||||
mock_embed.side_effect = _embed
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [
|
||||
_make_chunk("docA", 0),
|
||||
_make_chunk("docA", 1),
|
||||
_make_chunk("docA", 2),
|
||||
_make_chunk("docA", 3),
|
||||
_make_chunk("docB", 0),
|
||||
_make_chunk("docB", 1),
|
||||
]
|
||||
_embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
# docA should only appear in batch 0, not batch 1
|
||||
batch_1_doc_ids = embedded_doc_ids[3:]
|
||||
assert "docA" not in batch_1_doc_ids
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
|
||||
def test_failed_doc_skipped_in_later_batch_while_other_doc_succeeds(
|
||||
self, mock_embed: MagicMock
|
||||
) -> None:
|
||||
"""doc1 spans batches 0 and 1, doc2 only in batch 1. Batch 0 fails
|
||||
doc1. In batch 1, doc1 chunks should be skipped but doc2 chunks
|
||||
should still be embedded successfully."""
|
||||
embedded_chunks: list[list[str]] = []
|
||||
|
||||
def _embed(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
embedded_chunks.append([c.source_document.id for c in chunks])
|
||||
return _mock_embed_fail_doc("doc1")(chunks)
|
||||
|
||||
mock_embed.side_effect = _embed
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [
|
||||
_make_chunk("doc1", 0),
|
||||
_make_chunk("doc1", 1),
|
||||
_make_chunk("doc1", 2),
|
||||
_make_chunk("doc1", 3),
|
||||
_make_chunk("doc2", 0),
|
||||
_make_chunk("doc2", 1),
|
||||
]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
# doc1 should be fully excluded, doc2 fully included
|
||||
successful_doc_ids = {doc_id for _, doc_id in result.successful_chunk_ids}
|
||||
assert "doc1" not in successful_doc_ids
|
||||
assert "doc2" in successful_doc_ids
|
||||
assert len(result.successful_chunk_ids) == 2 # doc2's 2 chunks
|
||||
|
||||
# Batch 1 should only contain doc2 (doc1 was filtered before embedding)
|
||||
assert len(embedded_chunks) == 2
|
||||
assert "doc1" not in embedded_chunks[1]
|
||||
assert embedded_chunks[1] == ["doc2", "doc2"]
|
||||
|
||||
# Verify on-disk state has no doc1 chunks
|
||||
all_stored = list(store.stream())
|
||||
assert all(c.source_document.id == "doc2" for c in all_stored)
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
def test_empty_input(self, mock_embed: MagicMock) -> None:
|
||||
"""Empty chunk list produces empty results."""
|
||||
mock_embed.side_effect = _mock_embed_success
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=[],
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert len(result.successful_chunk_ids) == 0
|
||||
assert len(result.connector_failures) == 0
|
||||
mock_embed.assert_not_called()
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
|
||||
)
|
||||
@patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
|
||||
def test_all_chunks_fail(self, mock_embed: MagicMock) -> None:
|
||||
"""When all documents fail, results have no successful chunks."""
|
||||
|
||||
def _fail_all(
|
||||
chunks: list[DocAwareChunk], **_kwargs: object
|
||||
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
|
||||
doc_ids = {c.source_document.id for c in chunks}
|
||||
return [], [_make_failure(doc_id) for doc_id in doc_ids]
|
||||
|
||||
mock_embed.side_effect = _fail_all
|
||||
|
||||
with ChunkBatchStore() as store:
|
||||
chunks = [_make_chunk("doc1", 0), _make_chunk("doc2", 1)]
|
||||
result = _embed_chunks_to_store(
|
||||
chunks=chunks,
|
||||
embedder=MagicMock(),
|
||||
tenant_id="test",
|
||||
request_id=None,
|
||||
store=store,
|
||||
)
|
||||
|
||||
assert len(result.successful_chunk_ids) == 0
|
||||
assert len(result.connector_failures) == 2
|
||||
@@ -116,7 +116,7 @@ def _run_adapter_build(
|
||||
project_ids_map: dict[str, list[int]],
|
||||
persona_ids_map: dict[str, list[int]],
|
||||
) -> list[DocMetadataAwareIndexChunk]:
|
||||
"""Helper that runs UserFileIndexingAdapter.build_metadata_aware_chunks
|
||||
"""Helper that runs UserFileIndexingAdapter.prepare_enrichment + enrich_chunk
|
||||
with all external dependencies mocked."""
|
||||
from onyx.indexing.adapters.user_file_indexing_adapter import (
|
||||
UserFileIndexingAdapter,
|
||||
@@ -155,18 +155,16 @@ def _run_adapter_build(
|
||||
side_effect=Exception("no LLM in tests"),
|
||||
),
|
||||
):
|
||||
result = adapter.build_metadata_aware_chunks(
|
||||
chunks_with_embeddings=[chunk],
|
||||
chunk_content_scores=[1.0],
|
||||
tenant_id="test_tenant",
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id="test_tenant",
|
||||
chunks=[chunk],
|
||||
)
|
||||
|
||||
return result.chunks
|
||||
return [enricher.enrich_chunk(chunk, 1.0)]
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
|
||||
"""UserFileIndexingAdapter.build_metadata_aware_chunks writes persona IDs
|
||||
def test_prepare_enrichment_includes_persona_ids() -> None:
|
||||
"""UserFileIndexingAdapter.prepare_enrichment writes persona IDs
|
||||
fetched from the DB into each chunk's metadata."""
|
||||
file_id = str(uuid4())
|
||||
persona_ids = [5, 12]
|
||||
@@ -183,7 +181,7 @@ def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
|
||||
assert chunks[0].user_project == project_ids
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_missing_file_defaults_to_empty() -> None:
|
||||
def test_prepare_enrichment_missing_file_defaults_to_empty() -> None:
|
||||
"""When a file has no persona or project associations in the DB, the
|
||||
adapter should default to empty lists (not KeyError or None)."""
|
||||
file_id = str(uuid4())
|
||||
|
||||
Reference in New Issue
Block a user