mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-30 03:52:42 +00:00
Compare commits
14 Commits
dane/vecto
...
dane/index
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b5e30b829 | ||
|
|
a307900e4f | ||
|
|
f92a3c3d60 | ||
|
|
a33afcf912 | ||
|
|
c730850e81 | ||
|
|
b671bf4d4e | ||
|
|
6b7d9f4cfb | ||
|
|
399e251d85 | ||
|
|
6c38a28cf6 | ||
|
|
93d2f6d552 | ||
|
|
9124a6110d | ||
|
|
26850a42b3 | ||
|
|
c837d1ba80 | ||
|
|
fac7887542 |
88
backend/onyx/indexing/chunk_batch_store.py
Normal file
88
backend/onyx/indexing/chunk_batch_store.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import pickle
|
||||
import tempfile
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
|
||||
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
class ChunkBatchStore:
    """Spill batches of embedded chunks to pickle files in a temp directory.

    Every batch is written to ``batch_<idx>.pkl`` inside a private temporary
    directory. The store can save a batch, stream all stored chunks back in
    batch order, and scrub chunks that belong to failed documents.

    Use as a context manager to ensure cleanup::

        with ChunkBatchStore() as store:
            store.save(chunks, batch_idx=0)
            for chunk in store.stream():
                ...

    Calling any storage operation outside the ``with`` block raises
    ``RuntimeError``.
    """

    _EXT = ".pkl"

    def __init__(self) -> None:
        # NOTE: TemporaryDirectory creates the directory on construction;
        # it is removed when the context manager exits.
        self._tmpdir_ctx = tempfile.TemporaryDirectory(prefix="onyx_embeddings_")
        self._tmpdir: Path | None = None

    # -- context manager -----------------------------------------------------

    def __enter__(self) -> "ChunkBatchStore":
        self._tmpdir = Path(self._tmpdir_ctx.__enter__())
        return self

    def __exit__(self, *exc: object) -> None:
        try:
            self._tmpdir_ctx.__exit__(*exc)
        finally:
            # Mark the store unusable even if removing the directory failed.
            self._tmpdir = None

    @property
    def _dir(self) -> Path:
        """Return the active temp directory.

        Raises:
            RuntimeError: if the store is used outside its ``with`` block.
        """
        # Explicit raise instead of ``assert`` so the check survives ``-O``.
        if self._tmpdir is None:
            raise RuntimeError("ChunkBatchStore used outside context manager")
        return self._tmpdir

    # -- storage primitives --------------------------------------------------

    def _dump(self, batch_file: Path, chunks: list[IndexChunk]) -> None:
        """Pickle *chunks* into *batch_file*, replacing any previous content."""
        with open(batch_file, "wb") as f:
            pickle.dump(chunks, f)

    def save(self, chunks: list[IndexChunk], batch_idx: int) -> None:
        """Serialize a batch of embedded chunks to ``batch_<batch_idx>.pkl``."""
        self._dump(self._dir / f"batch_{batch_idx}{self._EXT}", chunks)

    def _load(self, batch_file: Path) -> list[IndexChunk]:
        """Deserialize a batch of embedded chunks from a file."""
        # NOTE(security): pickle is only safe on trusted data. The files read
        # here are expected to be the ones this store wrote to its own temp
        # directory; never point this at externally supplied files.
        with open(batch_file, "rb") as f:
            return pickle.load(f)

    def _batch_files(self) -> list[Path]:
        """Return batch files sorted by numeric index (not lexicographically)."""
        return sorted(
            self._dir.glob(f"batch_*{self._EXT}"),
            key=lambda p: int(p.stem.removeprefix("batch_")),
        )

    # -- higher-level operations ---------------------------------------------

    def stream(self) -> Iterator[IndexChunk]:
        """Yield all chunks across all batch files, in batch-index order.

        Each call returns a fresh generator, so the data can be iterated
        multiple times (e.g. once per document index). Only one batch is
        loaded at a time.
        """
        for batch_file in self._batch_files():
            yield from self._load(batch_file)

    def scrub_failed_docs(self, failed_doc_ids: set[str]) -> None:
        """Remove chunks belonging to *failed_doc_ids* from all batch files.

        When a document fails embedding in batch N, earlier batches may
        already contain successfully embedded chunks for that document.
        This ensures the output is all-or-nothing per document.
        """
        for batch_file in self._batch_files():
            batch_chunks = self._load(batch_file)
            cleaned = [
                c for c in batch_chunks if c.source_document.id not in failed_doc_ids
            ]
            # Only rewrite files that actually lost chunks.
            if len(cleaned) != len(batch_chunks):
                self._dump(batch_file, cleaned)
|
||||
@@ -1,7 +1,8 @@
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterable
|
||||
from typing import cast
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -11,6 +12,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
|
||||
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
|
||||
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
|
||||
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
|
||||
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
|
||||
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
|
||||
@@ -45,6 +47,7 @@ from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.indexing.chunk_batch_store import ChunkBatchStore
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
@@ -66,6 +69,7 @@ from onyx.natural_language_processing.utils import tokenizer_trim_middle
|
||||
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
|
||||
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
|
||||
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -104,6 +108,11 @@ class IndexingPipelineResult(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ChunkEmbeddingResult(BaseModel):
    """Summary returned after embedding chunks into a ``ChunkBatchStore``."""

    # One ``(chunk_id, document_id)`` pair per successfully embedded chunk.
    successful_chunk_ids: list[tuple[int, str]]
    # Failures collected while embedding.
    connector_failures: list[ConnectorFailure]
|
||||
|
||||
|
||||
class IndexingPipelineProtocol(Protocol):
|
||||
def __call__(
|
||||
self,
|
||||
@@ -151,6 +160,110 @@ def _upsert_documents_in_db(
|
||||
)
|
||||
|
||||
|
||||
def _get_failed_doc_ids(failures: list[ConnectorFailure]) -> set[str]:
    """Extract document IDs from a list of connector failures.

    Failures whose ``failed_document`` is ``None`` carry no document ID and
    are ignored.
    """
    return {
        failure.failed_document.document_id
        for failure in failures
        if failure.failed_document is not None
    }
|
||||
|
||||
|
||||
def _embed_chunks_to_store(
    chunks: list[DocAwareChunk],
    embedder: IndexingEmbedder,
    tenant_id: str,
    request_id: str | None,
    store: ChunkBatchStore,
) -> ChunkEmbeddingResult:
    """Embed chunks in batches, spilling each batch to *store*.

    If a document fails embedding in any batch, its chunks are excluded from
    all batches (including earlier ones already written) so that the output
    is all-or-nothing per document.

    Args:
        chunks: Chunks to embed; split into batches of at most
            ``MAX_CHUNKS_PER_DOC_BATCH``.
        embedder: Passed through to ``embed_chunks_with_failure_handling``.
        tenant_id: Passed through to ``embed_chunks_with_failure_handling``.
        request_id: Passed through to ``embed_chunks_with_failure_handling``.
        store: Open ``ChunkBatchStore`` that receives one file per embedded
            batch, keyed by the batch index.

    Returns:
        The ``(chunk_id, document_id)`` pairs of every chunk left in the
        store, plus all embedding failures that were reported.
    """
    successful_chunk_ids: list[tuple[int, str]] = []
    all_embedding_failures: list[ConnectorFailure] = []
    # Track failed doc IDs across all batches so that a failure in batch N
    # causes chunks for that doc to be skipped in batch N+1 and stripped
    # from earlier batches.
    all_failed_doc_ids: set[str] = set()
    # Doc IDs that have at least one chunk written to the store. Used to
    # decide whether the (disk-heavy) scrub pass is needed at all.
    stored_doc_ids: set[str] = set()

    for batch_idx, chunk_batch in enumerate(
        batch_generator(chunks, MAX_CHUNKS_PER_DOC_BATCH)
    ):
        # Skip chunks belonging to documents that failed in earlier batches.
        chunk_batch = [
            c for c in chunk_batch if c.source_document.id not in all_failed_doc_ids
        ]
        if not chunk_batch:
            continue

        logger.debug(f"Embedding batch {batch_idx}: {len(chunk_batch)} chunks")

        chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
            chunks=chunk_batch,
            embedder=embedder,
            tenant_id=tenant_id,
            request_id=request_id,
        )
        all_embedding_failures.extend(embedding_failures)
        all_failed_doc_ids.update(_get_failed_doc_ids(embedding_failures))

        # Only keep successfully embedded chunks for non-failed docs.
        chunks_with_embeddings = [
            c
            for c in chunks_with_embeddings
            if c.source_document.id not in all_failed_doc_ids
        ]

        for c in chunks_with_embeddings:
            successful_chunk_ids.append((c.chunk_id, c.source_document.id))
            stored_doc_ids.add(c.source_document.id)

        store.save(chunks_with_embeddings, batch_idx)
        # Drop the reference so the batch can be freed before the next one.
        del chunks_with_embeddings

    # Scrub earlier batches for docs that failed in later batches. Docs that
    # failed before any of their chunks were stored need no scrubbing, so the
    # batch files are only re-read when something stale is actually on disk.
    stale_doc_ids = all_failed_doc_ids & stored_doc_ids
    if stale_doc_ids:
        store.scrub_failed_docs(stale_doc_ids)
        successful_chunk_ids = [
            (chunk_id, doc_id)
            for chunk_id, doc_id in successful_chunk_ids
            if doc_id not in stale_doc_ids
        ]

    return ChunkEmbeddingResult(
        successful_chunk_ids=successful_chunk_ids,
        connector_failures=all_embedding_failures,
    )
|
||||
|
||||
|
||||
@contextmanager
def embed_and_stream(
    chunks: list[DocAwareChunk],
    embedder: IndexingEmbedder,
    tenant_id: str,
    request_id: str | None,
) -> Generator[tuple[ChunkEmbeddingResult, ChunkBatchStore], None, None]:
    """Embed *chunks* to disk and hand back ``(result, store)``.

    The embedded chunks live in the store's temp directory, which is removed
    as soon as the ``with`` block is left, so consume ``store.stream()``
    inside the block.

    Usage::

        with embed_and_stream(chunks, embedder, tenant_id, req_id) as (result, store):
            for chunk in store.stream():
                ...
    """
    batch_store = ChunkBatchStore()
    with batch_store:
        embedding_result = _embed_chunks_to_store(
            chunks=chunks,
            embedder=embedder,
            tenant_id=tenant_id,
            request_id=request_id,
            store=batch_store,
        )
        yield embedding_result, batch_store
|
||||
|
||||
|
||||
def get_doc_ids_to_update(
|
||||
documents: list[Document], db_docs: list[DBDocument]
|
||||
) -> list[Document]:
|
||||
@@ -649,6 +762,29 @@ def add_contextual_summaries(
|
||||
return chunks
|
||||
|
||||
|
||||
def _verify_indexing_completeness(
    insertion_records: list[DocumentInsertionRecord],
    write_failures: list[ConnectorFailure],
    embedding_failed_doc_ids: set[str],
    updatable_ids: list[str],
    document_index_name: str,
) -> None:
    """Verify that every updatable document was either indexed or reported as failed.

    The set of accounted-for document IDs is the union of inserted documents,
    documents with a write failure, and documents that failed embedding. It
    must equal the set of *updatable_ids* exactly.

    Raises:
        RuntimeError: if the accounted-for IDs differ from *updatable_ids*.
    """
    all_returned_doc_ids = (
        {r.document_id for r in insertion_records}
        | {f.failed_document.document_id for f in write_failures if f.failed_document}
        | embedding_failed_doc_ids
    )
    if all_returned_doc_ids == set(updatable_ids):
        return

    raise RuntimeError(
        f"Some documents were not successfully indexed. "
        f"Updatable IDs: {updatable_ids}, "
        f"Returned IDs: {all_returned_doc_ids}. "
        f"This should never happen. "
        f"This occurred for document index {document_index_name}"
    )
|
||||
|
||||
|
||||
@log_function_time(debug_only=True)
|
||||
def index_doc_batch(
|
||||
*,
|
||||
@@ -723,127 +859,99 @@ def index_doc_batch(
|
||||
)
|
||||
|
||||
logger.debug("Starting embedding")
|
||||
chunks_with_embeddings, embedding_failures = (
|
||||
embed_chunks_with_failure_handling(
|
||||
chunks=chunks,
|
||||
embedder=embedder,
|
||||
tenant_id=tenant_id,
|
||||
request_id=request_id,
|
||||
)
|
||||
if chunks
|
||||
else ([], [])
|
||||
)
|
||||
|
||||
chunk_content_scores = [1.0] * len(chunks_with_embeddings)
|
||||
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
updatable_chunk_data = [
|
||||
UpdatableChunkData(
|
||||
chunk_id=chunk.chunk_id,
|
||||
document_id=chunk.source_document.id,
|
||||
boost_score=score,
|
||||
)
|
||||
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
|
||||
]
|
||||
|
||||
# Acquires a lock on the documents so that no other process can modify them
|
||||
# NOTE: don't need to acquire till here, since this is when the actual race condition
|
||||
# with Vespa can occur.
|
||||
with adapter.lock_context(context.updatable_docs):
|
||||
# we're concerned about race conditions where multiple simultaneous indexings might result
|
||||
# in one set of metadata overwriting another one in vespa.
|
||||
# we still write data here for the immediate and most likely correct sync, but
|
||||
# to resolve this, an update of the last modified field at the end of this loop
|
||||
# always triggers a final metadata sync via the celery queue
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=tenant_id,
|
||||
chunks=cast(list[DocAwareChunk], chunks_with_embeddings),
|
||||
)
|
||||
|
||||
metadata_aware_chunks = [
|
||||
enricher.enrich_chunk(chunk, score)
|
||||
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
|
||||
with embed_and_stream(chunks, embedder, tenant_id, request_id) as (
|
||||
embedding_result,
|
||||
chunk_store,
|
||||
):
|
||||
updatable_ids = [doc.id for doc in context.updatable_docs]
|
||||
updatable_chunk_data = [
|
||||
UpdatableChunkData(
|
||||
chunk_id=chunk_id,
|
||||
document_id=document_id,
|
||||
boost_score=1.0,
|
||||
)
|
||||
for chunk_id, document_id in embedding_result.successful_chunk_ids
|
||||
]
|
||||
|
||||
short_descriptor_list = [
|
||||
chunk.to_short_descriptor() for chunk in metadata_aware_chunks
|
||||
]
|
||||
short_descriptor_log = str(short_descriptor_list)[:1024]
|
||||
logger.debug(f"Indexing the following chunks: {short_descriptor_log}")
|
||||
|
||||
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = None
|
||||
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = None
|
||||
|
||||
def chunk_iterable_creator() -> Iterable[DocMetadataAwareIndexChunk]:
|
||||
return metadata_aware_chunks
|
||||
|
||||
for document_index in document_indices:
|
||||
# A document will not be spread across different batches, so all the
|
||||
# documents with chunks in this set, are fully represented by the chunks
|
||||
# in this set
|
||||
(
|
||||
insertion_records,
|
||||
vector_db_write_failures,
|
||||
) = write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
make_chunks=chunk_iterable_creator,
|
||||
index_batch_params=IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
),
|
||||
)
|
||||
|
||||
all_returned_doc_ids: set[str] = (
|
||||
{record.document_id for record in insertion_records}
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in vector_db_write_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in embedding_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
)
|
||||
if all_returned_doc_ids != set(updatable_ids):
|
||||
raise RuntimeError(
|
||||
f"Some documents were not successfully indexed. "
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Returned IDs: {all_returned_doc_ids}. "
|
||||
"This should never happen."
|
||||
f"This occurred for document index {document_index.__class__.__name__}"
|
||||
)
|
||||
# We treat the first document index we got as the primary one used
|
||||
# for reporting the state of indexing.
|
||||
if primary_doc_idx_insertion_records is None:
|
||||
primary_doc_idx_insertion_records = insertion_records
|
||||
if primary_doc_idx_vector_db_write_failures is None:
|
||||
primary_doc_idx_vector_db_write_failures = vector_db_write_failures
|
||||
|
||||
adapter.post_index(
|
||||
context=context,
|
||||
updatable_chunk_data=updatable_chunk_data,
|
||||
filtered_documents=filtered_documents,
|
||||
enrichment=enricher,
|
||||
embedding_failed_doc_ids = _get_failed_doc_ids(
|
||||
embedding_result.connector_failures
|
||||
)
|
||||
|
||||
# Filter to only successfully embedded chunks so
|
||||
# doc_id_to_new_chunk_cnt reflects what's actually written to Vespa.
|
||||
embedded_chunks = [
|
||||
c for c in chunks if c.source_document.id not in embedding_failed_doc_ids
|
||||
]
|
||||
|
||||
# Acquires a lock on the documents so that no other process can modify
|
||||
# them. Not needed until here, since this is when the actual race
|
||||
# condition with vector db can occur.
|
||||
with adapter.lock_context(context.updatable_docs):
|
||||
enricher = adapter.prepare_enrichment(
|
||||
context=context,
|
||||
tenant_id=tenant_id,
|
||||
chunks=embedded_chunks,
|
||||
)
|
||||
|
||||
index_batch_params = IndexBatchParams(
|
||||
doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt,
|
||||
doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt,
|
||||
tenant_id=tenant_id,
|
||||
large_chunks_enabled=chunker.enable_large_chunks,
|
||||
)
|
||||
|
||||
primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = (
|
||||
None
|
||||
)
|
||||
primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = (
|
||||
None
|
||||
)
|
||||
|
||||
for document_index in document_indices:
|
||||
|
||||
def _enriched_stream() -> Iterator[DocMetadataAwareIndexChunk]:
|
||||
for chunk in chunk_store.stream():
|
||||
yield enricher.enrich_chunk(chunk, 1.0)
|
||||
|
||||
insertion_records, write_failures = (
|
||||
write_chunks_to_vector_db_with_backoff(
|
||||
document_index=document_index,
|
||||
make_chunks=_enriched_stream,
|
||||
index_batch_params=index_batch_params,
|
||||
)
|
||||
)
|
||||
|
||||
_verify_indexing_completeness(
|
||||
insertion_records=insertion_records,
|
||||
write_failures=write_failures,
|
||||
embedding_failed_doc_ids=embedding_failed_doc_ids,
|
||||
updatable_ids=updatable_ids,
|
||||
document_index_name=document_index.__class__.__name__,
|
||||
)
|
||||
# We treat the first document index we got as the primary one used
|
||||
# for reporting the state of indexing.
|
||||
if primary_doc_idx_insertion_records is None:
|
||||
primary_doc_idx_insertion_records = insertion_records
|
||||
if primary_doc_idx_vector_db_write_failures is None:
|
||||
primary_doc_idx_vector_db_write_failures = write_failures
|
||||
|
||||
adapter.post_index(
|
||||
context=context,
|
||||
updatable_chunk_data=updatable_chunk_data,
|
||||
filtered_documents=filtered_documents,
|
||||
enrichment=enricher,
|
||||
)
|
||||
|
||||
assert primary_doc_idx_insertion_records is not None
|
||||
assert primary_doc_idx_vector_db_write_failures is not None
|
||||
return IndexingPipelineResult(
|
||||
new_docs=len(
|
||||
[r for r in primary_doc_idx_insertion_records if not r.already_existed]
|
||||
new_docs=sum(
|
||||
1 for r in primary_doc_idx_insertion_records if not r.already_existed
|
||||
),
|
||||
total_docs=len(filtered_documents),
|
||||
total_chunks=len(chunks_with_embeddings),
|
||||
failures=primary_doc_idx_vector_db_write_failures + embedding_failures,
|
||||
total_chunks=len(embedding_result.successful_chunk_ids),
|
||||
failures=primary_doc_idx_vector_db_write_failures
|
||||
+ embedding_result.connector_failures,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
External dependency unit tests for UserFileIndexingAdapter metadata writing.
|
||||
|
||||
Validates that build_metadata_aware_chunks produces DocMetadataAwareIndexChunk
|
||||
Validates that prepare_enrichment produces DocMetadataAwareIndexChunk
|
||||
objects with both `user_project` and `personas` fields populated correctly
|
||||
based on actual DB associations.
|
||||
|
||||
@@ -127,7 +127,7 @@ def _make_index_chunk(user_file: UserFile) -> IndexChunk:
|
||||
|
||||
|
||||
class TestAdapterWritesBothMetadataFields:
|
||||
"""build_metadata_aware_chunks must populate user_project AND personas."""
|
||||
"""prepare_enrichment must populate user_project AND personas."""
|
||||
|
||||
@patch(
|
||||
"onyx.indexing.adapters.user_file_indexing_adapter.get_default_llm",
|
||||
|
||||
390
backend/tests/unit/onyx/indexing/test_embed_chunks_in_batches.py
Normal file
390
backend/tests/unit/onyx/indexing/test_embed_chunks_in_batches.py
Normal file
@@ -0,0 +1,390 @@
|
||||
"""Unit tests for _embed_chunks_to_store.
|
||||
|
||||
Tests cover:
|
||||
- Single batch, no failures
|
||||
- Multiple batches, no failures
|
||||
- Failure in a single batch
|
||||
- Cross-batch document failure scrubbing
|
||||
- Later batches skip already-failed docs
|
||||
- Empty input
|
||||
- All chunks fail
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
from unittest.mock import MagicMock
from unittest.mock import patch

from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import DocumentSource
from onyx.connectors.models import TextSection
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.indexing_pipeline import _embed_chunks_to_store
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexChunk
|
||||
|
||||
|
||||
def _make_doc(doc_id: str) -> Document:
    """Build a minimal FILE-source ``Document`` with the given ID."""
    only_section = TextSection(text="test", link=None)
    return Document(
        id=doc_id,
        semantic_identifier="test",
        source=DocumentSource.FILE,
        sections=[only_section],
        metadata={},
    )
|
||||
|
||||
|
||||
def _make_chunk(doc_id: str, chunk_id: int) -> DocAwareChunk:
    """Build a not-yet-embedded chunk belonging to document *doc_id*."""
    parent_doc = _make_doc(doc_id)
    return DocAwareChunk(
        # Identity
        chunk_id=chunk_id,
        source_document=parent_doc,
        large_chunk_id=None,
        # Content
        blurb="test",
        content="test content",
        source_links=None,
        image_file_id=None,
        section_continuation=False,
        mini_chunk_texts=None,
        # Prefixes / suffixes and contextual-RAG fields
        title_prefix="",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        doc_summary="",
        chunk_context="",
        contextual_rag_reserved_tokens=0,
    )
|
||||
|
||||
|
||||
def _make_index_chunk(doc_id: str, chunk_id: int) -> IndexChunk:
    """Build an ``IndexChunk``: the same fields as ``_make_chunk`` plus embeddings."""
    parent_doc = _make_doc(doc_id)
    fake_embedding = ChunkEmbedding(
        full_embedding=[0.1] * 10,
        mini_chunk_embeddings=[],
    )
    return IndexChunk(
        # Identity
        chunk_id=chunk_id,
        source_document=parent_doc,
        large_chunk_id=None,
        # Content
        blurb="test",
        content="test content",
        source_links=None,
        image_file_id=None,
        section_continuation=False,
        mini_chunk_texts=None,
        # Prefixes / suffixes and contextual-RAG fields
        title_prefix="",
        metadata_suffix_semantic="",
        metadata_suffix_keyword="",
        doc_summary="",
        chunk_context="",
        contextual_rag_reserved_tokens=0,
        # Embedding payload
        embeddings=fake_embedding,
        title_embedding=None,
    )
|
||||
|
||||
|
||||
def _make_failure(doc_id: str) -> ConnectorFailure:
    """Build a ``ConnectorFailure`` reporting an embedding failure for *doc_id*."""
    failed_doc = DocumentFailure(document_id=doc_id, document_link=None)
    cause = RuntimeError("embedding failed")
    return ConnectorFailure(
        failed_document=failed_doc,
        failure_message="embedding failed",
        exception=cause,
    )
|
||||
|
||||
|
||||
def _mock_embed_success(
    chunks: list[DocAwareChunk], **_kwargs: object
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
    """Stand-in embedder: every chunk embeds, no failures are reported."""
    embedded: list[IndexChunk] = []
    for chunk in chunks:
        embedded.append(_make_index_chunk(chunk.source_document.id, chunk.chunk_id))
    no_failures: list[ConnectorFailure] = []
    return embedded, no_failures
|
||||
|
||||
|
||||
def _mock_embed_fail_doc(
    fail_doc_id: str,
) -> Callable[..., tuple[list[IndexChunk], list[ConnectorFailure]]]:
    """Return an embed mock that fails all chunks for a specific doc.

    The returned callable accepts the chunk list (plus any keyword
    arguments, which are ignored). It embeds every chunk except those of
    *fail_doc_id*, and reports a single failure for that document when at
    least one of its chunks was in the batch.
    """

    def _embed(
        chunks: list[DocAwareChunk], **_kwargs: object
    ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
        successes = [
            _make_index_chunk(c.source_document.id, c.chunk_id)
            for c in chunks
            if c.source_document.id != fail_doc_id
        ]
        # One failure per document, regardless of how many chunks it had.
        failures = (
            [_make_failure(fail_doc_id)]
            if any(c.source_document.id == fail_doc_id for c in chunks)
            else []
        )
        return successes, failures

    return _embed
|
||||
|
||||
|
||||
class TestEmbedChunksInBatches:
    """Exercise ``_embed_chunks_to_store`` with a patched embedder and batch size."""

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
    def test_single_batch_no_failures(self, mock_embed: MagicMock) -> None:
        """Three chunks, one batch, everything embeds."""
        mock_embed.side_effect = _mock_embed_success
        doc_chunks = [_make_chunk("doc1", idx) for idx in range(3)]

        with ChunkBatchStore() as store:
            outcome = _embed_chunks_to_store(
                chunks=doc_chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            assert len(outcome.successful_chunk_ids) == 3
            assert len(outcome.connector_failures) == 0

            # Exactly one batch file holding all three chunks.
            assert len(store._batch_files()) == 1
            assert len(list(store.stream())) == 3

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
    def test_multiple_batches_no_failures(self, mock_embed: MagicMock) -> None:
        """Seven chunks with batch size three yield three batch files."""
        mock_embed.side_effect = _mock_embed_success
        doc_chunks = [_make_chunk("doc1", idx) for idx in range(7)]

        with ChunkBatchStore() as store:
            outcome = _embed_chunks_to_store(
                chunks=doc_chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            assert len(outcome.successful_chunk_ids) == 7
            assert len(outcome.connector_failures) == 0
            assert len(store._batch_files()) == 3  # 3 + 3 + 1

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
    def test_single_batch_with_failure(self, mock_embed: MagicMock) -> None:
        """A doc that fails embedding contributes no successful chunks."""
        mock_embed.side_effect = _mock_embed_fail_doc("doc2")
        doc_chunks = [
            _make_chunk("doc1", 0),
            _make_chunk("doc2", 1),
            _make_chunk("doc1", 2),
        ]

        with ChunkBatchStore() as store:
            outcome = _embed_chunks_to_store(
                chunks=doc_chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            assert len(outcome.connector_failures) == 1
            ok_doc_ids = {doc_id for _, doc_id in outcome.successful_chunk_ids}
            assert "doc2" not in ok_doc_ids
            assert "doc1" in ok_doc_ids

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
    def test_cross_batch_failure_scrubs_earlier_batch(
        self, mock_embed: MagicMock
    ) -> None:
        """docA embeds in batch 0 but fails in batch 1, so its batch-0
        chunks must be removed from the stored batch file."""
        calls_seen: list[int] = []

        def _embed(
            chunks: list[DocAwareChunk], **_kwargs: object
        ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
            calls_seen.append(len(chunks))
            # First call succeeds; every later call fails docA.
            if len(calls_seen) == 1:
                return _mock_embed_success(chunks)
            return _mock_embed_fail_doc("docA")(chunks)

        mock_embed.side_effect = _embed
        doc_chunks = [_make_chunk("docA", idx) for idx in range(4)] + [
            _make_chunk("docB", idx) for idx in range(2)
        ]

        with ChunkBatchStore() as store:
            outcome = _embed_chunks_to_store(
                chunks=doc_chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            # docA is excluded from the reported successes ...
            ok_doc_ids = {doc_id for _, doc_id in outcome.successful_chunk_ids}
            assert "docA" not in ok_doc_ids
            assert "docB" in ok_doc_ids
            assert len(outcome.connector_failures) == 1

            # ... and from what is left on disk.
            on_disk_doc_ids = {c.source_document.id for c in store.stream()}
            assert "docA" not in on_disk_doc_ids
            assert "docB" in on_disk_doc_ids

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
    def test_later_batch_skips_already_failed_doc(self, mock_embed: MagicMock) -> None:
        """Once docA has failed in batch 0, its batch-1 chunks are never
        handed to the embedder."""
        seen_doc_ids: list[str] = []

        def _embed(
            chunks: list[DocAwareChunk], **_kwargs: object
        ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
            seen_doc_ids.extend(c.source_document.id for c in chunks)
            return _mock_embed_fail_doc("docA")(chunks)

        mock_embed.side_effect = _embed
        doc_chunks = [_make_chunk("docA", idx) for idx in range(4)] + [
            _make_chunk("docB", idx) for idx in range(2)
        ]

        with ChunkBatchStore() as store:
            _embed_chunks_to_store(
                chunks=doc_chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

        # The first three recorded IDs are batch 0; the rest are batch 1.
        after_first_batch = seen_doc_ids[3:]
        assert "docA" not in after_first_batch

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 3)
    def test_failed_doc_skipped_in_later_batch_while_other_doc_succeeds(
        self, mock_embed: MagicMock
    ) -> None:
        """doc1 spans batches 0 and 1 and fails in batch 0; doc2 lives only
        in batch 1. doc1's remaining chunk is skipped while doc2 still embeds."""
        doc_ids_per_call: list[list[str]] = []

        def _embed(
            chunks: list[DocAwareChunk], **_kwargs: object
        ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
            doc_ids_per_call.append([c.source_document.id for c in chunks])
            return _mock_embed_fail_doc("doc1")(chunks)

        mock_embed.side_effect = _embed
        doc_chunks = [_make_chunk("doc1", idx) for idx in range(4)] + [
            _make_chunk("doc2", idx) for idx in range(2)
        ]

        with ChunkBatchStore() as store:
            outcome = _embed_chunks_to_store(
                chunks=doc_chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            # doc1 fully excluded, doc2 fully included.
            ok_doc_ids = {doc_id for _, doc_id in outcome.successful_chunk_ids}
            assert "doc1" not in ok_doc_ids
            assert "doc2" in ok_doc_ids
            assert len(outcome.successful_chunk_ids) == 2  # doc2's 2 chunks

            # The second embedder call saw doc2 only (doc1 filtered up front).
            assert len(doc_ids_per_call) == 2
            assert "doc1" not in doc_ids_per_call[1]
            assert doc_ids_per_call[1] == ["doc2", "doc2"]

            # Nothing belonging to doc1 is left on disk.
            on_disk = list(store.stream())
            assert all(c.source_document.id == "doc2" for c in on_disk)

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    def test_empty_input(self, mock_embed: MagicMock) -> None:
        """No chunks in means no results out and no embedder call."""
        mock_embed.side_effect = _mock_embed_success

        with ChunkBatchStore() as store:
            outcome = _embed_chunks_to_store(
                chunks=[],
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            assert len(outcome.successful_chunk_ids) == 0
            assert len(outcome.connector_failures) == 0
            mock_embed.assert_not_called()

    @patch(
        "onyx.indexing.indexing_pipeline.embed_chunks_with_failure_handling",
    )
    @patch("onyx.indexing.indexing_pipeline.MAX_CHUNKS_PER_DOC_BATCH", 100)
    def test_all_chunks_fail(self, mock_embed: MagicMock) -> None:
        """If every document fails, there are no successful chunks at all."""

        def _fail_all(
            chunks: list[DocAwareChunk], **_kwargs: object
        ) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
            failing_doc_ids = {c.source_document.id for c in chunks}
            return [], [_make_failure(doc_id) for doc_id in failing_doc_ids]

        mock_embed.side_effect = _fail_all
        doc_chunks = [_make_chunk("doc1", 0), _make_chunk("doc2", 1)]

        with ChunkBatchStore() as store:
            outcome = _embed_chunks_to_store(
                chunks=doc_chunks,
                embedder=MagicMock(),
                tenant_id="test",
                request_id=None,
                store=store,
            )

            assert len(outcome.successful_chunk_ids) == 0
            assert len(outcome.connector_failures) == 2
|
||||
@@ -163,8 +163,8 @@ def _run_adapter_build(
|
||||
return [enricher.enrich_chunk(chunk, 1.0)]
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
|
||||
"""UserFileIndexingAdapter.build_metadata_aware_chunks writes persona IDs
|
||||
def test_prepare_enrichment_includes_persona_ids() -> None:
|
||||
"""UserFileIndexingAdapter.prepare_enrichment writes persona IDs
|
||||
fetched from the DB into each chunk's metadata."""
|
||||
file_id = str(uuid4())
|
||||
persona_ids = [5, 12]
|
||||
@@ -181,7 +181,7 @@ def test_build_metadata_aware_chunks_includes_persona_ids() -> None:
|
||||
assert chunks[0].user_project == project_ids
|
||||
|
||||
|
||||
def test_build_metadata_aware_chunks_missing_file_defaults_to_empty() -> None:
|
||||
def test_prepare_enrichment_missing_file_defaults_to_empty() -> None:
|
||||
"""When a file has no persona or project associations in the DB, the
|
||||
adapter should default to empty lists (not KeyError or None)."""
|
||||
file_id = str(uuid4())
|
||||
|
||||
Reference in New Issue
Block a user