forked from github/onyx

Compare commits

...

12 Commits

Author           SHA1        Message                 Date
Evan Lohn        bfcf39bff7  final score 28.4        2025-07-01 20:52:11 -07:00
Evan Lohn        aa4a10e016  oops                    2025-07-01 18:42:16 -07:00
Evan Lohn        19292aff65  v0.1                    2025-06-30 17:25:05 -07:00
Evan Lohn        9a2f4e86dd  WIP                     2025-06-30 17:25:05 -07:00
Evan Lohn        6d835f7808  sys prompt has sources  2025-06-30 17:25:03 -07:00
joachim-danswer  013bed3157  fix                     2025-06-30 15:19:42 -07:00
joachim-danswer  289f27c43a  updates                 2025-06-30 15:06:12 -07:00
joachim-danswer  736a9bd332  erase history           2025-06-30 09:01:23 -07:00
joachim-danswer  8bcad415bb  nit                     2025-06-30 08:16:43 -07:00
joachim-danswer  93e6e4a089  mypy nits               2025-06-30 07:49:55 -07:00
joachim-danswer  ed0062dce0  fix                     2025-06-30 02:45:03 -07:00
joachim-danswer  6e8bf3120c  hackathon v1 changes    2025-06-30 01:39:36 -07:00
14 changed files with 535 additions and 59 deletions

View File

@@ -5,6 +5,7 @@ from pydantic import model_validator
from sqlalchemy.orm import Session
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import RerankingDetails
from onyx.db.models import Persona
from onyx.file_store.utils import InMemoryChatFile
@@ -71,6 +72,7 @@ class GraphSearchConfig(BaseModel):
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
allow_agent_reranking: bool = False
gen_excerpts: bool = True
kg_config_settings: KGConfigSettings = KGConfigSettings()
@@ -93,3 +95,10 @@ class GraphConfig(BaseModel):
class Config:
arbitrary_types_allowed = True
class GeneratedExcerpt(BaseModel):
"""A generated excerpt from a document"""
excerpt: str
document_source: DocumentSource
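
For reference, a minimal usage sketch of the new model (hypothetical values; GOOGLE_DRIVE is assumed to be a member of the DocumentSource enum):

from onyx.configs.constants import DocumentSource

# GeneratedExcerpt pairs an imagined "plausible" excerpt with the source
# type it would most likely live in; gen_excerpts: bool = True turns the
# feature on by default in GraphSearchConfig.
excerpt = GeneratedExcerpt(
    excerpt="Our PTO policy allows 20 days of paid leave per year.",
    document_source=DocumentSource.GOOGLE_DRIVE,  # assumed enum member
)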

View File

@@ -1,3 +1,4 @@
import json
from typing import cast
from uuid import uuid4
@@ -22,6 +23,9 @@ from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEAR
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.llm.factory import get_default_llms
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.prompts.chat_prompts import GEN_EXCERPT_PROMPT
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
@@ -37,6 +41,7 @@ from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -51,9 +56,21 @@ def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
else:
continue
history_segments.append(f"{role}:\n {msg.content}\n\n")
return "\n".join(history_segments)
HISTORY_PROMPT_MAP = {
(QueryExpansionType.KEYWORD, True): QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT,
(QueryExpansionType.KEYWORD, False): QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT,
(QueryExpansionType.SEMANTIC, True): QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT,
(
QueryExpansionType.SEMANTIC,
False,
): QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT,
}
def _expand_query(
query: str,
expansion_type: QueryExpansionType,
@@ -62,18 +79,11 @@ def _expand_query(
    history_str = _create_history_str(prompt_builder)

+   base_prompt = HISTORY_PROMPT_MAP[(expansion_type, bool(history_str))]
+   format_args = {"question": query}
    if history_str:
-       if expansion_type == QueryExpansionType.KEYWORD:
-           base_prompt = QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
-       else:
-           base_prompt = QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
-       expansion_prompt = base_prompt.format(question=query, history=history_str)
-   else:
-       if expansion_type == QueryExpansionType.KEYWORD:
-           base_prompt = QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
-       else:
-           base_prompt = QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
-       expansion_prompt = base_prompt.format(question=query)
+       format_args["history"] = history_str
+   expansion_prompt = base_prompt.format(**format_args)
msg = HumanMessage(content=expansion_prompt)
primary_llm, _ = get_default_llms()
@@ -83,15 +93,34 @@ def _expand_query(
return rephrased_query
# TODO: right now the llm describes the places where stuff could be, I want it to hallucinate real content
# fine for now since I'm just tryna get the pipeline going
def _gen_excerpts(
prompt_builder: AnswerPromptBuilder,
llm: LLM,
) -> list[str]:
user_prompt = prompt_builder.build()
prompt_str = GEN_EXCERPT_PROMPT.format(prompt=user_prompt)
excerpts_str = message_to_string(llm.invoke(prompt_str))
try:
excerpts = json.loads(excerpts_str)
except json.JSONDecodeError:
excerpts = [
exc for exc in excerpts_str.split('"') if len(exc) > 5
] # TODO: jank
if not isinstance(excerpts, list):
raise ValueError(f"Excerpts is not a list: {excerpts}")
return excerpts
def _expand_query_non_tool_calling_llm(
expanded_keyword_thread: TimeoutThread[str],
expanded_semantic_thread: TimeoutThread[str],
) -> QueryExpansions | None:
-   keyword_expansion: str | None = wait_on_background(expanded_keyword_thread)
-   semantic_expansion: str | None = wait_on_background(expanded_semantic_thread)
-   if keyword_expansion is None or semantic_expansion is None:
-       return None
+   keyword_expansion: str = wait_on_background(expanded_keyword_thread)
+   semantic_expansion: str = wait_on_background(expanded_semantic_thread)
return QueryExpansions(
keywords_expansions=[keyword_expansion],
@@ -123,6 +152,7 @@ def choose_tool(
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
expanded_keyword_thread: TimeoutThread[str] | None = None
expanded_semantic_thread: TimeoutThread[str] | None = None
gen_excerpts_thread: TimeoutThread[list[str]] | None = None
# If we have override_kwargs, add them to the tool_args
override_kwargs: SearchToolOverrideKwargs = (
force_use_tool.override_kwargs or SearchToolOverrideKwargs()
@@ -166,6 +196,12 @@ def choose_tool(
QueryExpansionType.SEMANTIC,
prompt_builder,
)
if agent_config.behavior.gen_excerpts:
gen_excerpts_thread = run_in_background(
_gen_excerpts,
agent_config.inputs.prompt_builder,
llm,
)
structured_response_format = agent_config.inputs.structured_response_format
tools = [
@@ -224,6 +260,10 @@ def choose_tool(
):
raise ValueError("No expanded keyword or semantic threads found.")
if gen_excerpts_thread:
excerpts = wait_on_background(gen_excerpts_thread)
override_kwargs.gen_excerpts = excerpts
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
@@ -339,6 +379,10 @@ def choose_tool(
):
raise ValueError("No expanded keyword or semantic threads found.")
if gen_excerpts_thread:
excerpts = wait_on_background(gen_excerpts_thread)
override_kwargs.gen_excerpts = excerpts
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
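
Condensed from the hunks above, the excerpt generation reuses the same background-thread pattern as the query expansions (run_in_background / wait_on_background are the threadpool helpers imported in this file):

# Kick off excerpt generation concurrently with tool selection...
gen_excerpts_thread = run_in_background(
    _gen_excerpts,
    agent_config.inputs.prompt_builder,
    llm,
)

# ...do the other work (query expansion, tool choice)...

# ...then join the thread and stash the excerpts on the override kwargs
# so SearchTool picks them up downstream.
excerpts = wait_on_background(gen_excerpts_thread)
override_kwargs.gen_excerpts = excerpts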

View File

@@ -460,14 +460,14 @@ def should_index(
# f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
# )
return False
-       else:
-           if (
-               connector.id == 0 or connector.source == DocumentSource.INGESTION_API
-           ):  # Ingestion API
-               # print(
-               #     f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source"
-               # )
-               return False
+       # else:
+       #     if (
+       #         connector.source == DocumentSource.INGESTION_API
+       #     ):  # Ingestion API
+       #         # print(
+       #         #     f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source"
+       #         # )
+       #         return False
return True
# If the connector is paused or is the ingestion API, don't index

View File

@@ -1,7 +1,12 @@
import csv
import json
import os
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import GraphConfig
@@ -11,6 +16,9 @@ from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_agent_search_graph
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import (
run_basic_graph as run_hackathon_graph,
) # You can create your own graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.chat.models import AgentAnswerPiece
@@ -22,9 +30,11 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.models import ToolCallFinalResult
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import RerankingDetails
@@ -44,6 +54,190 @@ logger = setup_logger()
BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
def _calc_score_for_pos(pos: int, max_acceptable_pos: int = 15) -> float:
"""
Calculate the score for a given position.
"""
if pos > max_acceptable_pos:
return 0
elif pos == 1:
return 1
elif pos == 2:
return 0.8
else:
return 4 / (pos + 5)
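
A few spot checks of the position-based score, derived directly from the branches above:

assert _calc_score_for_pos(1) == 1      # top hit: full credit
assert _calc_score_for_pos(2) == 0.8    # explicit override (4 / 7 would be ~0.57)
assert _calc_score_for_pos(3) == 0.5    # falls through to 4 / (3 + 5)
assert _calc_score_for_pos(16) == 0     # past max_acceptable_pos=15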
def _clean_doc_id_link(doc_link: str) -> str:
"""
Clean the google doc link.
"""
if "google.com" in doc_link:
if "/edit" in doc_link:
return "/edit".join(doc_link.split("/edit")[:-1])
elif "/view" in doc_link:
return "/view".join(doc_link.split("/view")[:-1])
else:
return doc_link
if "app.fireflies.ai" in doc_link:
return "?".join(doc_link.split("?")[:-1])
return doc_link
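
Illustrative inputs and outputs for _clean_doc_id_link (hypothetical URLs):

# Google Docs links are truncated at the /edit or /view suffix:
_clean_doc_id_link("https://docs.google.com/document/d/abc123/edit?usp=sharing")
# -> "https://docs.google.com/document/d/abc123"

# Fireflies links are stripped of their query string:
_clean_doc_id_link("https://app.fireflies.ai/view/meeting-xyz?tab=summary")
# -> "https://app.fireflies.ai/view/meeting-xyz"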
def _get_doc_score(doc_id: str, doc_results: list[str]) -> float:
"""
Get the score of a document from the document results.
"""
match_pos = None
for pos, comp_doc in enumerate(doc_results, start=1):
clear_doc_id = _clean_doc_id_link(doc_id)
clear_comp_doc = _clean_doc_id_link(comp_doc)
if clear_doc_id == clear_comp_doc:
match_pos = pos
if match_pos is None:
return 0.0
return _calc_score_for_pos(match_pos)
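
End to end, the two helpers combine like this (hypothetical links); the ground-truth doc matches at position 3 after link cleaning, contributing 4 / (3 + 5) = 0.5:

score = _get_doc_score(
    "https://docs.google.com/document/d/abc123/edit",
    [
        "https://example.com/unrelated-doc",
        "https://app.fireflies.ai/view/meeting-xyz?tab=summary",
        "https://docs.google.com/document/d/abc123/view",
    ],
)
assert score == 0.5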
def _append_empty_line(csv_path: str = HACKATHON_OUTPUT_CSV_PATH) -> None:
"""
Append an empty line to the CSV file.
"""
_append_answer_to_csv("", "", csv_path)
def _append_ground_truth_to_csv(
query: str,
ground_truth_docs: list[str],
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append the ground-truth document links to the CSV file.
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
for doc_id in ground_truth_docs:
writer.writerow([query, "-1", _clean_doc_id_link(doc_id), "", ""])
logger.debug("Appended score to csv file")
def _append_score_to_csv(
query: str,
score: float,
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append the score to the CSV file.
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
writer.writerow([query, "", "", "", score])
logger.debug("Appended score to csv file")
def _append_search_results_to_csv(
query: str,
doc_results: list[str],
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append the search results to the CSV file.
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
for pos, doc in enumerate(doc_results, start=1):
writer.writerow([query, pos, _clean_doc_id_link(doc), "", ""])
logger.debug("Appended search results to csv file")
def _append_answer_to_csv(
query: str,
answer: str,
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append a query/answer pair to the CSV file.
Args:
query: The user query
answer: The answer text to record
csv_path: Path to the CSV file to append to
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
writer.writerow([query, "", "", answer, ""])
logger.debug("Appended answer to csv file")
class Answer:
def __init__(
self,
@@ -134,6 +328,9 @@ class Answer:
@property
def processed_streamed_output(self) -> AnswerStream:
_HACKATHON_TEST_EXECUTION = False
if self._processed_stream is not None:
yield from self._processed_stream
return
@@ -154,22 +351,118 @@ class Answer:
)
):
run_langgraph = run_dc_graph
elif (
self.graph_config.inputs.persona
and self.graph_config.inputs.persona.description.startswith(
"Hackathon Test"
)
):
_HACKATHON_TEST_EXECUTION = True
run_langgraph = run_hackathon_graph
else:
run_langgraph = run_basic_graph
-       stream = run_langgraph(
-           self.graph_config,
-       )
if _HACKATHON_TEST_EXECUTION:
-       processed_stream = []
-       for packet in stream:
-           if self.is_cancelled():
-               packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
input_data = str(self.graph_config.inputs.prompt_builder.raw_user_query)
if input_data.startswith("["):
input_type = "json"
input_list = json.loads(input_data)
else:
input_type = "list"
input_list = input_data.split(";")
num_examples_with_ground_truth = 0
total_score = 0.0
question = ""
for question_num, question_data in enumerate(input_list):
ground_truth_docs = None
if input_type == "json":
question = question_data["question"]
ground_truth = question_data.get("ground_truth")
if ground_truth:
ground_truth_docs = [x.get("doc_link") for x in ground_truth]
logger.info(f"Question {question_num}: {question}")
_append_ground_truth_to_csv(question, ground_truth_docs)
else:
continue
else:
question = question_data
self.graph_config.inputs.prompt_builder.raw_user_query = question
self.graph_config.inputs.prompt_builder.user_message_and_token_cnt = (
HumanMessage(
content=question, additional_kwargs={}, response_metadata={}
),
2,
)
self.graph_config.tooling.force_use_tool.force_use = True
stream = run_langgraph(
self.graph_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
llm_answer_segments: list[str] = []
doc_results: list[str] | None = None
for answer_piece in processed_stream:
if isinstance(answer_piece, OnyxAnswerPiece):
llm_answer_segments.append(answer_piece.answer_piece or "")
elif isinstance(answer_piece, ToolCallFinalResult):
doc_results = [x.get("link") for x in answer_piece.tool_result]
if doc_results:
_append_search_results_to_csv(question, doc_results)
_append_answer_to_csv(question, "".join(llm_answer_segments))
if ground_truth_docs and doc_results:
num_examples_with_ground_truth += 1
doc_score = 0.0
for doc_id in ground_truth_docs:
doc_score += _get_doc_score(doc_id, doc_results)
_append_score_to_csv(question, doc_score)
total_score += doc_score
self._processed_stream = processed_stream
if num_examples_with_ground_truth > 0:
comprehensive_score = total_score / num_examples_with_ground_truth
else:
comprehensive_score = 0
_append_empty_line()
_append_score_to_csv(question, comprehensive_score)
else:
stream = run_langgraph(
self.graph_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
-               break
-           processed_stream.append(packet)
-           yield packet
-       self._processed_stream = processed_stream
            self._processed_stream = processed_stream
@property
def llm_answer(self) -> str:
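
The hackathon branch accepts the raw user query in either of two shapes; a sketch of the JSON form the parsing above expects (field names taken from the code, values invented):

example_input = json.dumps(
    [
        {
            "question": "What is our refund policy?",
            "ground_truth": [
                {"doc_link": "https://docs.google.com/document/d/abc123/edit"}
            ],
        },
        # JSON entries without ground_truth are skipped (the `continue` above).
        # The alternative plain form is a semicolon-separated string, e.g.
        # "What is our refund policy?;Who owns the billing service?"
    ]
)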

View File

@@ -51,6 +51,7 @@ from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
@@ -79,6 +80,7 @@ from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists
@@ -392,6 +394,16 @@ def _get_persona_for_chat_session(
return persona
def _get_connected_sources(
user: User | None,
db_session: Session,
) -> list[DocumentSource]:
cc_pairs = get_connector_credential_pairs_for_user(
db_session, user, eager_load_connector=True
)
return list(set(cc_pair.connector.source for cc_pair in cc_pairs))
ChatPacket = (
StreamingError
| QADocsResponse
@@ -969,6 +981,7 @@ def stream_chat_message_objects(
]
)
connected_sources = _get_connected_sources(user, db_session)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=final_msg.message,
@@ -976,7 +989,9 @@ def stream_chat_message_objects(
files=latest_query_files,
single_message_history=single_message_history,
),
-       system_message=default_build_system_message(prompt_config, llm.config),
+       system_message=default_build_system_message(
+           prompt_config, llm.config, connected_sources
+       ),
message_history=message_history,
llm_config=llm.config,
raw_user_query=final_msg.message,

View File

@@ -10,6 +10,7 @@ from pydantic.v1 import BaseModel as BaseModel__v1
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.configs.constants import DocumentSource
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
@@ -24,6 +25,7 @@ from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.prompt_utils import include_connected_sources
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
@@ -34,6 +36,7 @@ from onyx.tools.tool import Tool
def default_build_system_message(
prompt_config: PromptConfig,
llm_config: LLMConfig,
connected_sources: list[DocumentSource] | None = None,
) -> SystemMessage | None:
system_prompt = prompt_config.system_prompt.strip()
# See https://simonwillison.net/tags/markdown/ for context on this temporary fix
@@ -48,6 +51,9 @@ def default_build_system_message(
prompt_config,
add_additional_info_if_no_tag=prompt_config.datetime_aware,
)
tag_handled_prompt = include_connected_sources(
tag_handled_prompt, connected_sources
)
if not tag_handled_prompt:
return None
@@ -187,8 +193,12 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
-       if self.new_messages_and_token_cnts:
-           final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
+       if (
+           self.new_messages_and_token_cnts
+           and isinstance(self.user_message_and_token_cnt[0].content, str)
+           and self.user_message_and_token_cnt[0].content.startswith("Refer")
+       ):
+           final_messages_with_tokens.extend(self.new_messages_and_token_cnts[-2:])
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -787,3 +787,7 @@ S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
# Forcing Vespa Language
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")
HACKATHON_OUTPUT_CSV_PATH = os.environ.get(
"HACKATHON_OUTPUT_CSV_PATH", "/tmp/hackathon_output.csv"
)

View File

@@ -176,6 +176,7 @@ class SearchRequest(ChunkContext):
precomputed_query_embedding: Embedding | None = None
precomputed_is_keyword: bool | None = None
precomputed_keywords: list[str] | None = None
gen_excerpts: list[str] | None = None
class SearchQuery(ChunkContext):
@@ -205,6 +206,7 @@ class SearchQuery(ChunkContext):
precomputed_query_embedding: Embedding | None = None
expanded_queries: QueryExpansions | None = None
gen_excerpts: list[str] | None = None
class RetrievalDetails(ChunkContext):

View File

@@ -270,4 +270,5 @@ def retrieval_preprocessing(
full_doc=search_request.full_doc,
precomputed_query_embedding=search_request.precomputed_query_embedding,
expanded_queries=search_request.expanded_queries,
gen_excerpts=search_request.gen_excerpts,
)

View File

@@ -161,9 +161,9 @@ def doc_index_retrieval(
keyword_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
semantic_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
-   top_base_chunks_standard_ranking_thread: (
-       TimeoutThread[list[InferenceChunkUncleaned]] | None
-   ) = None
+   top_base_chunks_standard_ranking_threads: list[
+       TimeoutThread[list[InferenceChunkUncleaned]]
+   ] = []
top_semantic_chunks_thread: TimeoutThread[list[InferenceChunkUncleaned]] | None = (
None
@@ -174,19 +174,22 @@ def doc_index_retrieval(
top_semantic_chunks: list[InferenceChunkUncleaned] | None = None
-   # original retrieveal method
-   top_base_chunks_standard_ranking_thread = run_in_background(
-       document_index.hybrid_retrieval,
-       query.query,
-       query_embedding,
-       query.processed_keywords,
-       query.filters,
-       query.hybrid_alpha,
-       query.recency_bias_multiplier,
-       query.num_hits,
-       QueryExpansionType.SEMANTIC,
-       query.offset,
-   )
+   # original retrieval method
+   top_base_chunks_standard_ranking_threads = [
+       run_in_background(
+           document_index.hybrid_retrieval,
+           to_search,
+           query_embedding,
+           query.processed_keywords,
+           query.filters,
+           query.hybrid_alpha,
+           query.recency_bias_multiplier,
+           query.num_hits,
+           QueryExpansionType.SEMANTIC,
+           query.offset,
+       )
+       for to_search in [query.query] + (query.gen_excerpts or [])
+   ]
if (
query.expanded_queries
@@ -245,9 +248,16 @@ def doc_index_retrieval(
query.offset,
)
-   top_base_chunks_standard_ranking = wait_on_background(
-       top_base_chunks_standard_ranking_thread
-   )
+   top_base_chunks_standard_ranking = [
+       chunk
+       for thread_results in zip(
+           *[
+               wait_on_background(thread)
+               for thread in top_base_chunks_standard_ranking_threads
+           ]
+       )
+       for chunk in thread_results
+   ]
top_keyword_chunks = wait_on_background(top_keyword_chunks_thread)
@@ -267,8 +277,12 @@ def doc_index_retrieval(
else:
-       top_base_chunks_standard_ranking = wait_on_background(
-           top_base_chunks_standard_ranking_thread
-       )
+       top_base_chunks_standard_ranking = sum(
+           [
+               wait_on_background(thread)
+               for thread in top_base_chunks_standard_ranking_threads
+           ],
+           [],
+       )
top_chunks = _dedupe_chunks(top_base_chunks_standard_ranking)
@@ -297,6 +311,8 @@ def doc_index_retrieval(
else:
normal_chunks.append(chunk)
normal_chunks.sort(key=lambda x: x.score or 0, reverse=True)
# If there are no large chunks, just return the normal chunks
if not retrieval_requests:
return cleanup_chunks(normal_chunks)
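
The zip-based merge above interleaves hits round-robin across the original query and each generated excerpt (zip also truncates to the shortest result list), while the sum(..., []) branch simply concatenates; a toy illustration with stand-in values:

runs = [
    ["q_hit_1", "q_hit_2"],    # results for query.query
    ["e1_hit_1", "e1_hit_2"],  # results for the first generated excerpt
    ["e2_hit_1", "e2_hit_2"],  # results for the second generated excerpt
]

interleaved = [hit for per_rank in zip(*runs) for hit in per_rank]
# -> ["q_hit_1", "e1_hit_1", "e2_hit_1", "q_hit_2", "e1_hit_2", "e2_hit_2"]

concatenated = sum(runs, [])
# -> ["q_hit_1", "q_hit_2", "e1_hit_1", "e1_hit_2", "e2_hit_1", "e2_hit_2"]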

View File

@@ -324,3 +324,22 @@ Respond with EXACTLY and ONLY one rephrased query.
Rephrased query for search engine:
""".strip()
GEN_EXCERPT_PROMPT = """
The following is the full prompt that a tool-calling LLM will use to produce queries against
a vectorstore.
<prompt>
{prompt}
</prompt>
Your job is to pick the three sources of information that you think are most likely to contain
relevant, authoritative information about the user query or search terms. For each chosen source,
generate a sentence or two that plausibly could exist in the real document(s) containing the desired information.
Try to imagine a variety of contexts in which the information could be found. For example,
information about a policy or service could exist in a google doc or a slack conversation,
but the language surrounding the information might be very different across sources.
Return the excerpts as a list of strings, i.e. ["excerpt1", "excerpt2", "excerpt3"]
""".strip()

View File

@@ -69,6 +69,19 @@ def handle_onyx_date_awareness(
return prompt_str
def include_connected_sources(
prompt_str: str,
connected_sources: list[DocumentSource] | None = None,
) -> str:
    if connected_sources:
        return (
            prompt_str
            + "\n\nYou are connected to the following sources: "
            + ", ".join([src.value for src in connected_sources])
        )
    return prompt_str
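
For example, the helper appends the connected source list to the end of the system prompt (enum values assumed to be the lowercase source names):

from onyx.configs.constants import DocumentSource

print(
    include_connected_sources(
        "You are a helpful assistant.",
        [DocumentSource.SLACK, DocumentSource.CONFLUENCE],
    )
)
# You are a helpful assistant.
#
# You are connected to the following sources: slack, confluence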
def build_task_prompt_reminders(
prompt: Prompt | PromptConfig,
use_language_hint: bool,

View File

@@ -83,6 +83,7 @@ class SearchToolOverrideKwargs(BaseModel):
kg_terms: list[str] | None = None
kg_sources: list[str] | None = None
kg_chunk_id_zero_only: bool | None = False
gen_excerpts: list[str] | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -1,7 +1,10 @@
import copy
import csv
import json
import os
from collections.abc import Callable
from collections.abc import Generator
from pathlib import Path
from typing import Any
from typing import cast
from typing import TypeVar
@@ -19,6 +22,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prune_and_merge import prune_and_merge_sections
from onyx.chat.prune_and_merge import prune_sections
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
@@ -62,6 +66,39 @@ SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"
def _append_ranking_stats_to_csv(
llm_doc_results: list[tuple[int, str]],
query: str,
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
"""
Append ranking statistics to a CSV file.
Args:
llm_doc_results: List of (position, document_id) tuples to record
query: The query the documents were retrieved for
csv_path: Path to the CSV file to append to
"""
file_exists = os.path.isfile(csv_path)
# Create directory if it doesn't exist
csv_dir = os.path.dirname(csv_path)
if csv_dir and not os.path.exists(csv_dir):
Path(csv_dir).mkdir(parents=True, exist_ok=True)
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
writer = csv.writer(file)
# Write header if file is new
if not file_exists:
writer.writerow(["query", "position", "document_id", "answer", "score"])
# Write the ranking stats
for pos, doc in llm_doc_results:
writer.writerow([query, pos, doc, "", ""])
logger.debug(f"Appended {len(llm_doc_results)} ranking stats to {csv_path}")
class SearchResponseSummary(SearchQueryInfo):
top_sections: list[InferenceSection]
rephrased_query: str | None = None
@@ -302,6 +339,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
kg_terms = None
kg_sources = None
kg_chunk_id_zero_only = False
gen_excerpts = []
if override_kwargs:
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
alternate_db_session = override_kwargs.alternate_db_session
@@ -319,6 +357,10 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
kg_terms = override_kwargs.kg_terms
kg_sources = override_kwargs.kg_sources
kg_chunk_id_zero_only = override_kwargs.kg_chunk_id_zero_only or False
precomputed_is_keyword = override_kwargs.precomputed_is_keyword
precomputed_keywords = override_kwargs.precomputed_keywords
precomputed_query_embedding = override_kwargs.precomputed_query_embedding
gen_excerpts = override_kwargs.gen_excerpts or []
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
@@ -395,6 +437,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
precomputed_keywords=precomputed_keywords,
# add expanded queries
expanded_queries=expanded_queries,
gen_excerpts=gen_excerpts,
),
user=self.user,
llm=self.llm,
@@ -499,6 +542,12 @@ def yield_search_responses(
)
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
# Append ranking statistics to a CSV file
llm_doc_results = []
for pos, doc in enumerate(llm_docs):
llm_doc_results.append((pos, doc.document_id))
# _append_ranking_stats_to_csv(llm_doc_results, query)
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)