forked from github/onyx

Compare commits


6 Commits

Author        SHA1        Message                                   Date
Evan Lohn     43273eeb5a  remove prints                             2025-04-02 21:29:29 -07:00
Evan Lohn     27c22c5409  fix unit tests and minor pruning bug      2025-04-02 21:21:48 -07:00
evan-danswer  71938c643b  comments                                  2025-04-02 15:07:25 -07:00
evan-danswer  c6c11cf04f  comments                                  2025-04-02 15:05:20 -07:00
Evan Lohn     bbb8f5e9e2  better approach to length restriction     2025-04-02 13:07:08 -07:00
Evan Lohn     700c114e09  fix large docs selected in chat pruning   2025-04-02 12:31:56 -07:00
4 changed files with 63 additions and 29 deletions

View File

@@ -43,6 +43,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_me
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
+from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
@@ -692,8 +693,13 @@ def stream_chat_message_objects(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
+# Add a maximum context size in the case of user-selected docs to prevent
+# slight inaccuracies in context window size pruning from causing
+# the entire query to fail
document_pruning_config = DocumentPruningConfig(
-is_manually_selected_docs=True
+is_manually_selected_docs=True,
+max_window_percentage=SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE,
)
# In case the search doc is deleted, just don't include it
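
The new max_window_percentage cap bounds how much of the model's context window user-selected sections may occupy, leaving headroom so that small inaccuracies in token counting cannot push the whole request over the limit. A minimal sketch of that idea, assuming a hypothetical helper and an example window size (neither is onyx code):

def selected_sections_token_budget(
    context_window_tokens: int,
    max_window_percentage: float = 0.8,  # mirrors SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
) -> int:
    # Reserve the remaining fraction for prompts, chat history, and the response,
    # plus slack for slightly off per-section token estimates.
    return int(context_window_tokens * max_window_percentage)

print(selected_sections_token_budget(4096))  # 3276 tokens available for selected sections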

View File

@@ -312,11 +312,14 @@ def prune_sections(
)
-def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
+def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, int]:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"
+ADJACENT_CHUNK_SEP = "\n"
+DISTANT_CHUNK_SEP = "\n\n...\n\n"
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
@@ -324,33 +327,48 @@ def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
chunks, key=lambda x: x.score if x.score is not None else float("-inf")
)
+added_chars = 0
merged_content = []
for i, chunk in enumerate(sorted_chunks):
if i > 0:
prev_chunk_id = sorted_chunks[i - 1].chunk_id
-if chunk.chunk_id == prev_chunk_id + 1:
-merged_content.append("\n")
-else:
-merged_content.append("\n\n...\n\n")
+sep = (
+ADJACENT_CHUNK_SEP
+if chunk.chunk_id == prev_chunk_id + 1
+else DISTANT_CHUNK_SEP
+)
+merged_content.append(sep)
+added_chars += len(sep)
merged_content.append(chunk.content)
combined_content = "".join(merged_content)
-return InferenceSection(
-center_chunk=center_chunk,
-chunks=sorted_chunks,
-combined_content=combined_content,
+return (
+InferenceSection(
+center_chunk=center_chunk,
+chunks=sorted_chunks,
+combined_content=combined_content,
+),
+added_chars,
)
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
doc_order: dict[str, int] = {}
+combined_section_lengths: dict[str, int] = defaultdict(lambda: 0)
# chunk de-duping and doc ordering
for index, section in enumerate(sections):
if section.center_chunk.document_id not in doc_order:
doc_order[section.center_chunk.document_id] = index
+combined_section_lengths[section.center_chunk.document_id] += len(
+section.combined_content
+)
+chunks_map = docs_map[section.center_chunk.document_id]
for chunk in [section.center_chunk] + section.chunks:
-chunks_map = docs_map[section.center_chunk.document_id]
existing_chunk = chunks_map.get(chunk.chunk_id)
if (
existing_chunk is None
@@ -361,8 +379,22 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
chunks_map[chunk.chunk_id] = chunk
new_sections = []
-for section_chunks in docs_map.values():
-new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
+for doc_id, section_chunks in docs_map.items():
+section_chunks_list = list(section_chunks.values())
+merged_section, added_chars = _merge_doc_chunks(chunks=section_chunks_list)
+previous_length = combined_section_lengths[doc_id] + added_chars
+# After merging, ensure the content respects the pruning done earlier. Each
+# combined section is restricted to the sum of the lengths of the sections
+# from the pruning step. Technically the correct approach would be to prune based
+# on tokens AGAIN, but this is a good approximation and worth not adding the
+# tokenization overhead. This could also be fixed if we added a way of removing
+# chunks from sections in the pruning step; at the moment this issue largely
+# exists because we only trim the final section's combined_content.
+merged_section.combined_content = merged_section.combined_content[
+:previous_length
+]
+new_sections.append(merged_section)
# Sort by highest score, then by original document order
# It is now 1 large section per doc, the center chunk being the one with the highest score
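
Taken together, _merge_doc_chunks now also reports how many separator characters it inserted, and _merge_sections caps each merged document's combined_content at the total length the earlier pruning step allowed plus those separators. A self-contained sketch of the idea, using simplified stand-ins rather than onyx's real InferenceChunk/InferenceSection models:

from dataclasses import dataclass

ADJACENT_CHUNK_SEP = "\n"
DISTANT_CHUNK_SEP = "\n\n...\n\n"

@dataclass
class Chunk:
    chunk_id: int
    content: str

def merge_doc_chunks(chunks: list[Chunk]) -> tuple[str, int]:
    # Join one document's chunks in chunk-id order, counting the characters
    # contributed by the separators themselves.
    sorted_chunks = sorted(chunks, key=lambda c: c.chunk_id)
    parts: list[str] = []
    added_chars = 0
    for i, chunk in enumerate(sorted_chunks):
        if i > 0:
            sep = (
                ADJACENT_CHUNK_SEP
                if chunk.chunk_id == sorted_chunks[i - 1].chunk_id + 1
                else DISTANT_CHUNK_SEP
            )
            parts.append(sep)
            added_chars += len(sep)
        parts.append(chunk.content)
    return "".join(parts), added_chars

# Character-level cap: the pruned sections' combined lengths plus the separator
# characters merging introduced, approximating a second token-based prune.
pruned_lengths = [len("chunk one"), len("chunk three")]
merged, added = merge_doc_chunks([Chunk(1, "chunk one"), Chunk(3, "chunk three")])
capped = merged[: sum(pruned_lengths) + added]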

View File

@@ -16,6 +16,9 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
+# Maximum percentage of the context window to fill with selected sections
+SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(
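
For a rough sense of these two knobs: CHAT_TARGET_CHUNK_PERCENTAGE works out to 0.5 (three 512-token chunks against a ~3k-token budget), and the new constant caps user-selected sections at 80% of the window. The worked numbers below use example context-window sizes, not onyx defaults:

CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072  # = 0.5
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8

for window in (4096, 16384):  # example model context sizes, in tokens
    chunk_target = window * CHAT_TARGET_CHUNK_PERCENTAGE             # 2048.0, 8192.0
    selected_cap = window * SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE  # 3276.8, 13107.2
    print(window, chunk_target, selected_cap)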

View File

@@ -4,6 +4,7 @@ from onyx.chat.prune_and_merge import _merge_sections
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
+from onyx.context.search.utils import inference_section_from_chunks
# This large test accounts for all of the following:
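
The newly imported inference_section_from_chunks replaces the direct InferenceSection(...) constructions in the test cases below, so combined_content is derived from the chunks rather than hand-written as "N/A" (which matters now that _merge_sections reads each section's combined_content length). A hedged sketch of what such a helper plausibly does, with simplified stand-in types since its real body is not part of this diff:

from dataclasses import dataclass

@dataclass
class FakeChunk:
    chunk_id: int
    content: str

@dataclass
class FakeSection:
    center_chunk: FakeChunk
    chunks: list[FakeChunk]
    combined_content: str

def section_from_chunks(center_chunk: FakeChunk, chunks: list[FakeChunk]) -> FakeSection | None:
    # Assumed behavior: build the section and derive combined_content from the chunks.
    if not chunks:
        return None
    return FakeSection(
        center_chunk=center_chunk,
        chunks=chunks,
        combined_content="\n".join(c.content for c in chunks),
    )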
@@ -111,7 +112,7 @@ Content 17
# Sections
[
# Document 1, top/middle/bot connected + disconnected section
-InferenceSection(
+inference_section_from_chunks(
center_chunk=DOC_1_TOP_CHUNK,
chunks=[
DOC_1_FILLER_1,
@@ -120,9 +121,8 @@ Content 17
DOC_1_MID_CHUNK,
DOC_1_FILLER_3,
],
-combined_content="N/A", # Not used
),
-InferenceSection(
+inference_section_from_chunks(
center_chunk=DOC_1_MID_CHUNK,
chunks=[
DOC_1_FILLER_2,
@@ -131,9 +131,8 @@ Content 17
DOC_1_FILLER_3,
DOC_1_FILLER_4,
],
-combined_content="N/A",
),
-InferenceSection(
+inference_section_from_chunks(
center_chunk=DOC_1_BOTTOM_CHUNK,
chunks=[
DOC_1_FILLER_3,
@@ -142,9 +141,8 @@ Content 17
DOC_1_FILLER_5,
DOC_1_FILLER_6,
],
-combined_content="N/A",
),
-InferenceSection(
+inference_section_from_chunks(
center_chunk=DOC_1_DISCONNECTED,
chunks=[
DOC_1_FILLER_7,
@@ -153,9 +151,8 @@ Content 17
DOC_1_FILLER_9,
DOC_1_FILLER_10,
],
-combined_content="N/A",
),
-InferenceSection(
+inference_section_from_chunks(
center_chunk=DOC_2_TOP_CHUNK,
chunks=[
DOC_2_FILLER_1,
@@ -164,9 +161,8 @@ Content 17
DOC_2_FILLER_3,
DOC_2_BOTTOM_CHUNK,
],
-combined_content="N/A",
),
-InferenceSection(
+inference_section_from_chunks(
center_chunk=DOC_2_BOTTOM_CHUNK,
chunks=[
DOC_2_TOP_CHUNK,
@@ -175,7 +171,6 @@ Content 17
DOC_2_FILLER_4,
DOC_2_FILLER_5,
],
-combined_content="N/A",
),
],
# Expected Content
@@ -204,15 +199,13 @@ def test_merge_sections(
(
# Sections
[
-InferenceSection(
+inference_section_from_chunks(
center_chunk=DOC_1_TOP_CHUNK,
chunks=[DOC_1_TOP_CHUNK],
-combined_content="N/A", # Not used
),
-InferenceSection(
+inference_section_from_chunks(
center_chunk=DOC_1_MID_CHUNK,
chunks=[DOC_1_MID_CHUNK],
-combined_content="N/A",
),
],
# Expected Content