forked from github/onyx

Compare commits: main...hackathon-

12 Commits

| Author | SHA1 | Date |
|---|---|---|
| | bfcf39bff7 | |
| | aa4a10e016 | |
| | 19292aff65 | |
| | 9a2f4e86dd | |
| | 6d835f7808 | |
| | 013bed3157 | |
| | 289f27c43a | |
| | 736a9bd332 | |
| | 8bcad415bb | |
| | 93e6e4a089 | |
| | ed0062dce0 | |
| | 6e8bf3120c | |

@@ -5,6 +5,7 @@ from pydantic import model_validator
from sqlalchemy.orm import Session

from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import RerankingDetails
from onyx.db.models import Persona
from onyx.file_store.utils import InMemoryChatFile

@@ -71,6 +72,7 @@ class GraphSearchConfig(BaseModel):
    allow_refinement: bool = True
    skip_gen_ai_answer_generation: bool = False
    allow_agent_reranking: bool = False
    gen_excerpts: bool = True
    kg_config_settings: KGConfigSettings = KGConfigSettings()

@@ -93,3 +95,10 @@ class GraphConfig(BaseModel):
    class Config:
        arbitrary_types_allowed = True


class GeneratedExcerpt(BaseModel):
    """A generated excerpt from a document"""

    excerpt: str
    document_source: DocumentSource

@@ -1,3 +1,4 @@
import json
from typing import cast
from uuid import uuid4

@@ -22,6 +23,9 @@ from onyx.configs.chat_configs import USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEAR
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.llm.factory import get_default_llms
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.prompts.chat_prompts import GEN_EXCERPT_PROMPT
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
from onyx.prompts.chat_prompts import QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT

@@ -37,6 +41,7 @@ from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding


logger = setup_logger()


@@ -51,9 +56,21 @@ def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
        else:
            continue
        history_segments.append(f"{role}:\n {msg.content}\n\n")

    return "\n".join(history_segments)


HISTORY_PROMPT_MAP = {
    (QueryExpansionType.KEYWORD, True): QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT,
    (QueryExpansionType.KEYWORD, False): QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT,
    (QueryExpansionType.SEMANTIC, True): QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT,
    (
        QueryExpansionType.SEMANTIC,
        False,
    ): QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT,
}


def _expand_query(
    query: str,
    expansion_type: QueryExpansionType,

@@ -62,18 +79,11 @@ def _expand_query(

    history_str = _create_history_str(prompt_builder)

    base_prompt = HISTORY_PROMPT_MAP[(expansion_type, bool(history_str))]
    format_args = {"question": query}
    if history_str:
        if expansion_type == QueryExpansionType.KEYWORD:
            base_prompt = QUERY_KEYWORD_EXPANSION_WITH_HISTORY_PROMPT
        else:
            base_prompt = QUERY_SEMANTIC_EXPANSION_WITH_HISTORY_PROMPT
        expansion_prompt = base_prompt.format(question=query, history=history_str)
    else:
        if expansion_type == QueryExpansionType.KEYWORD:
            base_prompt = QUERY_KEYWORD_EXPANSION_WITHOUT_HISTORY_PROMPT
        else:
            base_prompt = QUERY_SEMANTIC_EXPANSION_WITHOUT_HISTORY_PROMPT
        expansion_prompt = base_prompt.format(question=query)
        format_args["history"] = history_str
    expansion_prompt = base_prompt.format(**format_args)

    msg = HumanMessage(content=expansion_prompt)
    primary_llm, _ = get_default_llms()

@@ -83,15 +93,34 @@ def _expand_query(
    return rephrased_query
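
As an aside, a minimal, self-contained sketch of the tuple-keyed prompt dispatch used by _expand_query above, with dummy prompt strings standing in for the real onyx prompt constants (illustrative only, not part of the diff):

from enum import Enum


class ExpansionType(Enum):  # stand-in for onyx's QueryExpansionType
    KEYWORD = "keyword"
    SEMANTIC = "semantic"


# Dummy prompts standing in for the onyx prompt constants.
PROMPT_MAP = {
    (ExpansionType.KEYWORD, True): "Keywords for: {question}\nHistory: {history}",
    (ExpansionType.KEYWORD, False): "Keywords for: {question}",
    (ExpansionType.SEMANTIC, True): "Rephrase: {question}\nHistory: {history}",
    (ExpansionType.SEMANTIC, False): "Rephrase: {question}",
}


def build_expansion_prompt(query: str, expansion_type: ExpansionType, history: str = "") -> str:
    # The (type, has-history) pair picks the prompt; history is only formatted in when present.
    base_prompt = PROMPT_MAP[(expansion_type, bool(history))]
    format_args = {"question": query}
    if history:
        format_args["history"] = history
    return base_prompt.format(**format_args)


print(build_expansion_prompt("how do I reset my password", ExpansionType.KEYWORD))
# -> "Keywords for: how do I reset my password"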

# TODO: right now the llm describes the places where stuff could be, I want it to hallucinate real content
# fine for now since I'm just tryna get the pipeline going
def _gen_excerpts(
    prompt_builder: AnswerPromptBuilder,
    llm: LLM,
) -> list[str]:
    user_prompt = prompt_builder.build()
    prompt_str = GEN_EXCERPT_PROMPT.format(prompt=user_prompt)
    excerpts_str = message_to_string(llm.invoke(prompt_str))
    try:
        excerpts = json.loads(excerpts_str)
    except json.JSONDecodeError:
        excerpts = [
            exc for exc in excerpts_str.split('"') if len(exc) > 5
        ]  # TODO: jank

    if not isinstance(excerpts, list):
        raise ValueError(f"Excerpts is not a list: {excerpts}")

    return excerpts
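
To make the two parse paths above concrete, a small standalone sketch with made-up LLM outputs (illustrative only, not part of the diff):

import json


def parse_excerpts(excerpts_str: str) -> list[str]:
    # Mirrors the parsing in _gen_excerpts: prefer JSON, otherwise keep
    # quoted fragments longer than 5 characters.
    try:
        excerpts = json.loads(excerpts_str)
    except json.JSONDecodeError:
        excerpts = [exc for exc in excerpts_str.split('"') if len(exc) > 5]
    if not isinstance(excerpts, list):
        raise ValueError(f"Excerpts is not a list: {excerpts}")
    return excerpts


print(parse_excerpts('["The VPN policy lives in the IT handbook", "See the #it-support pinned message"]'))
print(parse_excerpts('Here you go: "The VPN policy lives in the IT handbook" and "See the #it-support pinned message"'))
# The fallback also keeps surrounding chatter ("Here you go: "), which is why it is marked as jank above.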


def _expand_query_non_tool_calling_llm(
    expanded_keyword_thread: TimeoutThread[str],
    expanded_semantic_thread: TimeoutThread[str],
) -> QueryExpansions | None:
    keyword_expansion: str | None = wait_on_background(expanded_keyword_thread)
    semantic_expansion: str | None = wait_on_background(expanded_semantic_thread)

    if keyword_expansion is None or semantic_expansion is None:
        return None
    keyword_expansion: str = wait_on_background(expanded_keyword_thread)
    semantic_expansion: str = wait_on_background(expanded_semantic_thread)

    return QueryExpansions(
        keywords_expansions=[keyword_expansion],

@@ -123,6 +152,7 @@ def choose_tool(
    keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
    expanded_keyword_thread: TimeoutThread[str] | None = None
    expanded_semantic_thread: TimeoutThread[str] | None = None
    gen_excerpts_thread: TimeoutThread[list[str]] | None = None
    # If we have override_kwargs, add them to the tool_args
    override_kwargs: SearchToolOverrideKwargs = (
        force_use_tool.override_kwargs or SearchToolOverrideKwargs()

@@ -166,6 +196,12 @@ def choose_tool(
                QueryExpansionType.SEMANTIC,
                prompt_builder,
            )
        if agent_config.behavior.gen_excerpts:
            gen_excerpts_thread = run_in_background(
                _gen_excerpts,
                agent_config.inputs.prompt_builder,
                llm,
            )

    structured_response_format = agent_config.inputs.structured_response_format
    tools = [

@@ -224,6 +260,10 @@ def choose_tool(
    ):
        raise ValueError("No expanded keyword or semantic threads found.")

    if gen_excerpts_thread:
        excerpts = wait_on_background(gen_excerpts_thread)
        override_kwargs.gen_excerpts = excerpts

    return ToolChoiceUpdate(
        tool_choice=ToolChoice(
            tool=tool,

@@ -339,6 +379,10 @@ def choose_tool(
    ):
        raise ValueError("No expanded keyword or semantic threads found.")

    if gen_excerpts_thread:
        excerpts = wait_on_background(gen_excerpts_thread)
        override_kwargs.gen_excerpts = excerpts

    return ToolChoiceUpdate(
        tool_choice=ToolChoice(
            tool=selected_tool,

@@ -460,14 +460,14 @@ def should_index(
            # f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index.id}"
            # )
            return False
        else:
            if (
                connector.id == 0 or connector.source == DocumentSource.INGESTION_API
            ):  # Ingestion API
                # print(
                # f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source"
                # )
                return False
            # else:
            #     if (
            #         connector.source == DocumentSource.INGESTION_API
            #     ):  # Ingestion API
            #         # print(
            #         # f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source"
            #         # )
            #         return False
        return True

    # If the connector is paused or is the ingestion API, don't index

@@ -1,7 +1,12 @@
import csv
import json
import os
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from uuid import UUID

from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session

from onyx.agents.agent_search.models import GraphConfig

@@ -11,6 +16,9 @@ from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_agent_search_graph
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import (
    run_basic_graph as run_hackathon_graph,
)  # You can create your own graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.chat.models import AgentAnswerPiece

@@ -22,9 +30,11 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.models import ToolCallFinalResult
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import RerankingDetails

@@ -44,6 +54,190 @@ logger = setup_logger()
BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])


def _calc_score_for_pos(pos: int, max_acceptable_pos: int = 15) -> float:
    """
    Calculate the score for a given position.
    """
    if pos > max_acceptable_pos:
        return 0

    elif pos == 1:
        return 1
    elif pos == 2:
        return 0.8
    else:
        return 4 / (pos + 5)
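
For orientation, the curve above works out to the following values with the default max_acceptable_pos of 15 (a quick illustrative check, not part of the diff):

# Quick sanity check of _calc_score_for_pos (illustrative only):
for pos in (1, 2, 3, 5, 15, 16):
    print(pos, _calc_score_for_pos(pos))
# 1 -> 1, 2 -> 0.8, 3 -> 0.5, 5 -> 0.4, 15 -> 0.2, 16 -> 0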


def _clean_doc_id_link(doc_link: str) -> str:
    """
    Clean the google doc link.
    """
    if "google.com" in doc_link:
        if "/edit" in doc_link:
            return "/edit".join(doc_link.split("/edit")[:-1])
        elif "/view" in doc_link:
            return "/view".join(doc_link.split("/view")[:-1])
        else:
            return doc_link

    if "app.fireflies.ai" in doc_link:
        return "?".join(doc_link.split("?")[:-1])
    return doc_link
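
Two hypothetical links, shown only to illustrate the cleaning behavior (not part of the diff):

# Google Docs links are truncated at "/edit"; Fireflies links lose their query string.
_clean_doc_id_link("https://docs.google.com/document/d/abc123/edit?usp=sharing")
# -> "https://docs.google.com/document/d/abc123"
_clean_doc_id_link("https://app.fireflies.ai/view/weekly-sync::123?ref=email")
# -> "https://app.fireflies.ai/view/weekly-sync::123"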


def _get_doc_score(doc_id: str, doc_results: list[str]) -> float:
    """
    Get the score of a document from the document results.
    """

    match_pos = None
    for pos, comp_doc in enumerate(doc_results, start=1):

        clear_doc_id = _clean_doc_id_link(doc_id)
        clear_comp_doc = _clean_doc_id_link(comp_doc)

        if clear_doc_id == clear_comp_doc:
            match_pos = pos

    if match_pos is None:
        return 0.0

    return _calc_score_for_pos(match_pos)


def _append_empty_line(csv_path: str = HACKATHON_OUTPUT_CSV_PATH) -> None:
    """
    Append an empty line to the CSV file.
    """
    _append_answer_to_csv("", "", csv_path)


def _append_ground_truth_to_csv(
    query: str,
    ground_truth_docs: list[str],
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the score to the CSV file.
    """

    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist
    csv_dir = os.path.dirname(csv_path)
    if csv_dir and not os.path.exists(csv_dir):
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)

        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])

        # Write the ranking stats

        for doc_id in ground_truth_docs:
            writer.writerow([query, "-1", _clean_doc_id_link(doc_id), "", ""])

    logger.debug("Appended score to csv file")


def _append_score_to_csv(
    query: str,
    score: float,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the score to the CSV file.
    """

    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist
    csv_dir = os.path.dirname(csv_path)
    if csv_dir and not os.path.exists(csv_dir):
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)

        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])

        # Write the ranking stats

        writer.writerow([query, "", "", "", score])

    logger.debug("Appended score to csv file")


def _append_search_results_to_csv(
    query: str,
    doc_results: list[str],
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the search results to the CSV file.
    """

    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist
    csv_dir = os.path.dirname(csv_path)
    if csv_dir and not os.path.exists(csv_dir):
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)

        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])

        # Write the ranking stats

        for pos, doc in enumerate(doc_results, start=1):
            writer.writerow([query, pos, _clean_doc_id_link(doc), "", ""])

    logger.debug("Appended search results to csv file")


def _append_answer_to_csv(
    query: str,
    answer: str,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append ranking statistics to a CSV file.

    Args:
        ranking_stats: List of tuples containing (query, hit_position, document_id)
        csv_path: Path to the CSV file to append to
    """
    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist
    csv_dir = os.path.dirname(csv_path)
    if csv_dir and not os.path.exists(csv_dir):
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)

        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])

        # Write the ranking stats

        writer.writerow([query, "", "", answer, ""])

    logger.debug("Appended answer to csv file")

class Answer:
    def __init__(
        self,

@@ -134,6 +328,9 @@ class Answer:

    @property
    def processed_streamed_output(self) -> AnswerStream:

        _HACKATHON_TEST_EXECUTION = False

        if self._processed_stream is not None:
            yield from self._processed_stream
            return

@@ -154,22 +351,118 @@ class Answer:
            )
        ):
            run_langgraph = run_dc_graph

        elif (
            self.graph_config.inputs.persona
            and self.graph_config.inputs.persona.description.startswith(
                "Hackathon Test"
            )
        ):
            _HACKATHON_TEST_EXECUTION = True
            run_langgraph = run_hackathon_graph

        else:
            run_langgraph = run_basic_graph

        stream = run_langgraph(
            self.graph_config,
        )
        if _HACKATHON_TEST_EXECUTION:

            processed_stream = []
            for packet in stream:
                if self.is_cancelled():
                    packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
            input_data = str(self.graph_config.inputs.prompt_builder.raw_user_query)

            if input_data.startswith("["):
                input_type = "json"
                input_list = json.loads(input_data)
            else:
                input_type = "list"
                input_list = input_data.split(";")

            num_examples_with_ground_truth = 0
            total_score = 0.0

            question = ""
            for question_num, question_data in enumerate(input_list):

                ground_truth_docs = None
                if input_type == "json":
                    question = question_data["question"]
                    ground_truth = question_data.get("ground_truth")
                    if ground_truth:
                        ground_truth_docs = [x.get("doc_link") for x in ground_truth]
                        logger.info(f"Question {question_num}: {question}")
                        _append_ground_truth_to_csv(question, ground_truth_docs)
                    else:
                        continue
                else:
                    question = question_data

                self.graph_config.inputs.prompt_builder.raw_user_query = question
                self.graph_config.inputs.prompt_builder.user_message_and_token_cnt = (
                    HumanMessage(
                        content=question, additional_kwargs={}, response_metadata={}
                    ),
                    2,
                )
                self.graph_config.tooling.force_use_tool.force_use = True

                stream = run_langgraph(
                    self.graph_config,
                )
                processed_stream = []
                for packet in stream:
                    if self.is_cancelled():
                        packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                        yield packet
                        break
                    processed_stream.append(packet)
                    yield packet

                llm_answer_segments: list[str] = []
                doc_results: list[str] | None = None
                for answer_piece in processed_stream:
                    if isinstance(answer_piece, OnyxAnswerPiece):
                        llm_answer_segments.append(answer_piece.answer_piece or "")
                    elif isinstance(answer_piece, ToolCallFinalResult):
                        doc_results = [x.get("link") for x in answer_piece.tool_result]

                if doc_results:
                    _append_search_results_to_csv(question, doc_results)

                _append_answer_to_csv(question, "".join(llm_answer_segments))

                if ground_truth_docs and doc_results:
                    num_examples_with_ground_truth += 1
                    doc_score = 0.0
                    for doc_id in ground_truth_docs:
                        doc_score += _get_doc_score(doc_id, doc_results)

                    _append_score_to_csv(question, doc_score)
                    total_score += doc_score

            self._processed_stream = processed_stream

            if num_examples_with_ground_truth > 0:
                comprehensive_score = total_score / num_examples_with_ground_truth
            else:
                comprehensive_score = 0

            _append_empty_line()

            _append_score_to_csv(question, comprehensive_score)

        else:

            stream = run_langgraph(
                self.graph_config,
            )

            processed_stream = []
            for packet in stream:
                if self.is_cancelled():
                    packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
                    yield packet
                    break
                processed_stream.append(packet)
                yield packet
                    break
                processed_stream.append(packet)
                yield packet
            self._processed_stream = processed_stream
            self._processed_stream = processed_stream

    @property
    def llm_answer(self) -> str:

@@ -51,6 +51,7 @@ from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID

@@ -79,6 +80,7 @@ from onyx.db.chat import reserve_message_id
from onyx.db.chat import translate_db_message_to_chat_message_detail
from onyx.db.chat import translate_db_search_doc_to_server_search_doc
from onyx.db.chat import update_chat_session_updated_at_timestamp
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.milestone import check_multi_assistant_milestone
from onyx.db.milestone import create_milestone_if_not_exists

@@ -392,6 +394,16 @@ def _get_persona_for_chat_session(
    return persona


def _get_connected_sources(
    user: User | None,
    db_session: Session,
) -> list[DocumentSource]:
    cc_pairs = get_connector_credential_pairs_for_user(
        db_session, user, eager_load_connector=True
    )
    return list(set(cc_pair.connector.source for cc_pair in cc_pairs))


ChatPacket = (
    StreamingError
    | QADocsResponse

@@ -969,6 +981,7 @@ def stream_chat_message_objects(
            ]
        )

        connected_sources = _get_connected_sources(user, db_session)
        prompt_builder = AnswerPromptBuilder(
            user_message=default_build_user_message(
                user_query=final_msg.message,

@@ -976,7 +989,9 @@ def stream_chat_message_objects(
                files=latest_query_files,
                single_message_history=single_message_history,
            ),
            system_message=default_build_system_message(prompt_config, llm.config),
            system_message=default_build_system_message(
                prompt_config, llm.config, connected_sources
            ),
            message_history=message_history,
            llm_config=llm.config,
            raw_user_query=final_msg.message,

@@ -10,6 +10,7 @@ from pydantic.v1 import BaseModel as BaseModel__v1
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prompt_builder.utils import translate_history_to_basemessages
from onyx.configs.constants import DocumentSource
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.interfaces import LLMConfig
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME

@@ -24,6 +25,7 @@ from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.direct_qa_prompts import HISTORY_BLOCK
from onyx.prompts.prompt_utils import drop_messages_history_overflow
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.prompt_utils import include_connected_sources
from onyx.tools.force import ForceUseTool
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff

@@ -34,6 +36,7 @@ from onyx.tools.tool import Tool
def default_build_system_message(
    prompt_config: PromptConfig,
    llm_config: LLMConfig,
    connected_sources: list[DocumentSource] | None = None,
) -> SystemMessage | None:
    system_prompt = prompt_config.system_prompt.strip()
    # See https://simonwillison.net/tags/markdown/ for context on this temporary fix

@@ -48,6 +51,9 @@ def default_build_system_message(
        prompt_config,
        add_additional_info_if_no_tag=prompt_config.datetime_aware,
    )
    tag_handled_prompt = include_connected_sources(
        tag_handled_prompt, connected_sources
    )

    if not tag_handled_prompt:
        return None

@@ -187,8 +193,12 @@ class AnswerPromptBuilder:

        final_messages_with_tokens.append(self.user_message_and_token_cnt)

        if self.new_messages_and_token_cnts:
            final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
        if (
            self.new_messages_and_token_cnts
            and isinstance(self.user_message_and_token_cnt[0].content, str)
            and self.user_message_and_token_cnt[0].content.startswith("Refer")
        ):
            final_messages_with_tokens.extend(self.new_messages_and_token_cnts[-2:])

        return drop_messages_history_overflow(
            final_messages_with_tokens, self.max_tokens

@@ -787,3 +787,7 @@ S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
# Forcing Vespa Language
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")

HACKATHON_OUTPUT_CSV_PATH = os.environ.get(
    "HACKATHON_OUTPUT_CSV_PATH", "/tmp/hackathon_output.csv"
)

@@ -176,6 +176,7 @@ class SearchRequest(ChunkContext):
    precomputed_query_embedding: Embedding | None = None
    precomputed_is_keyword: bool | None = None
    precomputed_keywords: list[str] | None = None
    gen_excerpts: list[str] | None = None


class SearchQuery(ChunkContext):

@@ -205,6 +206,7 @@ class SearchQuery(ChunkContext):
    precomputed_query_embedding: Embedding | None = None

    expanded_queries: QueryExpansions | None = None
    gen_excerpts: list[str] | None = None


class RetrievalDetails(ChunkContext):

@@ -270,4 +270,5 @@ def retrieval_preprocessing(
        full_doc=search_request.full_doc,
        precomputed_query_embedding=search_request.precomputed_query_embedding,
        expanded_queries=search_request.expanded_queries,
        gen_excerpts=search_request.gen_excerpts,
    )

@@ -161,9 +161,9 @@ def doc_index_retrieval(

    keyword_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
    semantic_embeddings_thread: TimeoutThread[list[Embedding]] | None = None
    top_base_chunks_standard_ranking_thread: (
        TimeoutThread[list[InferenceChunkUncleaned]] | None
    ) = None
    top_base_chunks_standard_ranking_threads: list[
        TimeoutThread[list[InferenceChunkUncleaned]]
    ] = []

    top_semantic_chunks_thread: TimeoutThread[list[InferenceChunkUncleaned]] | None = (
        None

@@ -174,19 +174,22 @@ def doc_index_retrieval(

    top_semantic_chunks: list[InferenceChunkUncleaned] | None = None

    # original retrieveal method
    top_base_chunks_standard_ranking_thread = run_in_background(
        document_index.hybrid_retrieval,
        query.query,
        query_embedding,
        query.processed_keywords,
        query.filters,
        query.hybrid_alpha,
        query.recency_bias_multiplier,
        query.num_hits,
        QueryExpansionType.SEMANTIC,
        query.offset,
    )
    # original retrieval method
    top_base_chunks_standard_ranking_threads = [
        run_in_background(
            document_index.hybrid_retrieval,
            to_search,
            query_embedding,
            query.processed_keywords,
            query.filters,
            query.hybrid_alpha,
            query.recency_bias_multiplier,
            query.num_hits,
            QueryExpansionType.SEMANTIC,
            query.offset,
        )
        for to_search in [query.query] + (query.gen_excerpts or [])
    ]

    if (
        query.expanded_queries

@@ -245,9 +248,16 @@ def doc_index_retrieval(
            query.offset,
        )

        top_base_chunks_standard_ranking = wait_on_background(
            top_base_chunks_standard_ranking_thread
        )
        top_base_chunks_standard_ranking = [
            chunk
            for thread_results in zip(
                *[
                    wait_on_background(thread)
                    for thread in top_base_chunks_standard_ranking_threads
                ]
            )
            for chunk in thread_results
        ]
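
Because zip() stops at the shortest input, this comprehension interleaves one chunk per thread per rank and drops trailing chunks from longer result lists. A toy illustration with plain lists rather than onyx types (not part of the diff):

# Round-robin interleaving of per-thread results, as in the comprehension above.
results_per_thread = [["a1", "a2", "a3"], ["b1", "b2"], ["c1", "c2", "c3"]]
interleaved = [chunk for per_rank in zip(*results_per_thread) for chunk in per_rank]
print(interleaved)
# -> ['a1', 'b1', 'c1', 'a2', 'b2', 'c2']  ('a3' and 'c3' are cut off by zip)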

        top_keyword_chunks = wait_on_background(top_keyword_chunks_thread)

@@ -267,8 +277,12 @@ def doc_index_retrieval(

    else:

        top_base_chunks_standard_ranking = wait_on_background(
            top_base_chunks_standard_ranking_thread
        top_base_chunks_standard_ranking = sum(
            [
                wait_on_background(thread)
                for thread in top_base_chunks_standard_ranking_threads
            ],
            [],
        )

        top_chunks = _dedupe_chunks(top_base_chunks_standard_ranking)

@@ -297,6 +311,8 @@ def doc_index_retrieval(
        else:
            normal_chunks.append(chunk)

    normal_chunks.sort(key=lambda x: x.score or 0, reverse=True)

    # If there are no large chunks, just return the normal chunks
    if not retrieval_requests:
        return cleanup_chunks(normal_chunks)

@@ -324,3 +324,22 @@ Respond with EXACTLY and ONLY one rephrased query.

Rephrased query for search engine:
""".strip()


GEN_EXCERPT_PROMPT = """
The following is the full prompt that a tool-calling LLM will use to produce queries against
a vectorstore.

<prompt>
{prompt}
</prompt>

Your job is to pick the three sources of information that you think are most likely to contain
relevant, authoritative information about the user query or search terms. For each chosen source,
generate a sentence or two that plausibly could exist in the real document(s) containing the desired information.
Try to imagine a variety of contexts in which the information could be found. For example,
information about a policy or service could exist in a google doc or a slack conversation,
but the language surrounding the information might be very different across sources.

Return the excerpts as a list of strings, i.e. ["excerpt1", "excerpt2", "excerpt3"]
""".strip()

@@ -69,6 +69,19 @@ def handle_onyx_date_awareness(
    return prompt_str


def include_connected_sources(
    prompt_str: str,
    connected_sources: list[DocumentSource] | None = None,
) -> str:
    if connected_sources:
        return (
            prompt_str
            + "You are connected to the following sources: "
            + ", ".join([src.value for src in connected_sources])
        )
    return prompt_str
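
A hypothetical call, only to show the resulting prompt text; the DocumentSource members and their string values are assumed from the onyx enum, not taken from the diff:

# include_connected_sources concatenates directly, so the base prompt should end
# with whitespace or a newline to keep the sentence readable.
prompt = include_connected_sources(
    "Answer using the connected knowledge.\n",
    [DocumentSource.SLACK, DocumentSource.GOOGLE_DRIVE],
)
# -> "Answer using the connected knowledge.\nYou are connected to the following sources: slack, google_drive"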


def build_task_prompt_reminders(
    prompt: Prompt | PromptConfig,
    use_language_hint: bool,

@@ -83,6 +83,7 @@ class SearchToolOverrideKwargs(BaseModel):
    kg_terms: list[str] | None = None
    kg_sources: list[str] | None = None
    kg_chunk_id_zero_only: bool | None = False
    gen_excerpts: list[str] | None = None

    class Config:
        arbitrary_types_allowed = True

@@ -1,7 +1,10 @@
import copy
import csv
import json
import os
from collections.abc import Callable
from collections.abc import Generator
from pathlib import Path
from typing import Any
from typing import cast
from typing import TypeVar

@@ -19,6 +22,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prune_and_merge import prune_and_merge_sections
from onyx.chat.prune_and_merge import prune_sections
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS

@@ -62,6 +66,39 @@ SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"


def _append_ranking_stats_to_csv(
    llm_doc_results: list[tuple[int, str]],
    query: str,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append ranking statistics to a CSV file.

    Args:
        ranking_stats: List of tuples containing (query, hit_position, document_id)
        csv_path: Path to the CSV file to append to
    """
    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist
    csv_dir = os.path.dirname(csv_path)
    if csv_dir and not os.path.exists(csv_dir):
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)

        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])

        # Write the ranking stats
        for pos, doc in llm_doc_results:
            writer.writerow([query, pos, doc, ""])

    logger.debug(f"Appended {len(llm_doc_results)} ranking stats to {csv_path}")

class SearchResponseSummary(SearchQueryInfo):
    top_sections: list[InferenceSection]
    rephrased_query: str | None = None

@@ -302,6 +339,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
        kg_terms = None
        kg_sources = None
        kg_chunk_id_zero_only = False
        gen_excerpts = []
        if override_kwargs:
            force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
            alternate_db_session = override_kwargs.alternate_db_session

@@ -319,6 +357,10 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
            kg_terms = override_kwargs.kg_terms
            kg_sources = override_kwargs.kg_sources
            kg_chunk_id_zero_only = override_kwargs.kg_chunk_id_zero_only or False
            precomputed_is_keyword = override_kwargs.precomputed_is_keyword
            precomputed_keywords = override_kwargs.precomputed_keywords
            precomputed_query_embedding = override_kwargs.precomputed_query_embedding
            gen_excerpts = override_kwargs.gen_excerpts or []

        if self.selected_sections:
            yield from self._build_response_for_specified_sections(query)

@@ -395,6 +437,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
                precomputed_keywords=precomputed_keywords,
                # add expanded queries
                expanded_queries=expanded_queries,
                gen_excerpts=gen_excerpts,
            ),
            user=self.user,
            llm=self.llm,

@@ -499,6 +542,12 @@ def yield_search_responses(
    )
    llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]

    # Append ranking statistics to a CSV file
    llm_doc_results = []
    for pos, doc in enumerate(llm_docs):
        llm_doc_results.append((pos, doc.document_id))
    # _append_ranking_stats_to_csv(llm_doc_results, query)

    yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)