mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-10 10:12:40 +00:00
Compare commits
26 Commits
v3.0.0
...
ronnie/hac
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1f91fd51ee | ||
|
|
3accbd2c11 | ||
|
|
4a8348a169 | ||
|
|
bdadef00cd | ||
|
|
318b5c0c6a | ||
|
|
6cfdedb9d9 | ||
|
|
7782a735c2 | ||
|
|
060f097737 | ||
|
|
a294f4464c | ||
|
|
7f6eb68295 | ||
|
|
4165150b20 | ||
|
|
4f9968727c | ||
|
|
833bac077b | ||
|
|
f06ae0a340 | ||
|
|
b082e845fd | ||
|
|
81148e10e2 | ||
|
|
8b5ff48271 | ||
|
|
013bed3157 | ||
|
|
289f27c43a | ||
|
|
6a78343bf3 | ||
|
|
655a5f3d56 | ||
|
|
736a9bd332 | ||
|
|
8bcad415bb | ||
|
|
93e6e4a089 | ||
|
|
ed0062dce0 | ||
|
|
6e8bf3120c |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -26,3 +26,6 @@ jira_test_env
|
||||
/deployment/data/nginx/app.conf
|
||||
*.sw?
|
||||
/backend/tests/regression/answer_quality/search_test_config.yaml
|
||||
|
||||
# hackathon
|
||||
hackathon
|
||||
|
||||
@@ -82,13 +82,16 @@ if __name__ == "__main__":
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.llm.factory import get_default_llms
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
|
||||
SqlEngine.init_engine(pool_size=10, max_overflow=0)
|
||||
|
||||
graph = basic_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
input = BasicInput(unused=True)
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
config, _ = get_test_config(
|
||||
config = get_test_config(
|
||||
db_session=db_session,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
|
||||
@@ -51,6 +51,7 @@ def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
|
||||
else:
|
||||
continue
|
||||
history_segments.append(f"{role}:\n {msg.content}\n\n")
|
||||
|
||||
return "\n".join(history_segments)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
@@ -11,6 +16,9 @@ from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.run_graph import run_agent_search_graph
|
||||
from onyx.agents.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agents.agent_search.run_graph import (
|
||||
run_basic_graph as run_hackathon_graph,
|
||||
) # You can create your own graph
|
||||
from onyx.agents.agent_search.run_graph import run_dc_graph
|
||||
from onyx.agents.agent_search.run_graph import run_kb_graph
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
@@ -22,9 +30,11 @@ from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.models import ToolCallFinalResult
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
|
||||
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
|
||||
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
@@ -44,6 +54,191 @@ logger = setup_logger()
|
||||
BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
|
||||
|
||||
|
||||
def _calc_score_for_pos(pos: int, max_acceptable_pos: int = 15) -> float:
|
||||
"""
|
||||
Calculate the score for a given position.
|
||||
"""
|
||||
if pos > max_acceptable_pos:
|
||||
return 0
|
||||
|
||||
elif pos == 1:
|
||||
return 1
|
||||
elif pos == 2:
|
||||
return 0.8
|
||||
else:
|
||||
return 4 / (pos + 5)
|
||||
|
||||
|
||||
def _clean_doc_id_link(doc_link: str) -> str:
|
||||
"""
|
||||
Clean the google doc link.
|
||||
"""
|
||||
if "google.com" in doc_link:
|
||||
if "/edit" in doc_link:
|
||||
return "/edit".join(doc_link.split("/edit")[:-1])
|
||||
elif "/view" in doc_link:
|
||||
return "/view".join(doc_link.split("/view")[:-1])
|
||||
else:
|
||||
return doc_link
|
||||
|
||||
if "app.fireflies.ai" in doc_link:
|
||||
return "?".join(doc_link.split("?")[:-1])
|
||||
return doc_link
|
||||
|
||||
|
||||
def _get_doc_score(doc_id: str, doc_results: list[str]) -> float:
|
||||
"""
|
||||
Get the score of a document from the document results.
|
||||
"""
|
||||
|
||||
match_pos = None
|
||||
for pos, comp_doc in enumerate(doc_results, start=1):
|
||||
|
||||
clear_doc_id = _clean_doc_id_link(doc_id)
|
||||
clear_comp_doc = _clean_doc_id_link(comp_doc)
|
||||
|
||||
if clear_doc_id == clear_comp_doc:
|
||||
match_pos = pos
|
||||
|
||||
if match_pos is None:
|
||||
return 0.0
|
||||
|
||||
return _calc_score_for_pos(match_pos)
|
||||
|
||||
|
||||
def _append_empty_line(csv_path: str = HACKATHON_OUTPUT_CSV_PATH) -> None:
|
||||
"""
|
||||
Append an empty line to the CSV file.
|
||||
"""
|
||||
_append_answer_to_csv("", "", csv_path)
|
||||
|
||||
|
||||
def _append_ground_truth_to_csv(
|
||||
query: str,
|
||||
ground_truth_docs: list[str],
|
||||
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
|
||||
) -> None:
|
||||
"""
|
||||
Append the score to the CSV file.
|
||||
"""
|
||||
|
||||
file_exists = os.path.isfile(csv_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
csv_dir = os.path.dirname(csv_path)
|
||||
if csv_dir and not os.path.exists(csv_dir):
|
||||
Path(csv_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
|
||||
writer = csv.writer(file)
|
||||
|
||||
# Write header if file is new
|
||||
if not file_exists:
|
||||
writer.writerow(["query", "position", "document_id", "answer", "score"])
|
||||
|
||||
# Write the ranking stats
|
||||
|
||||
for doc_id in ground_truth_docs:
|
||||
writer.writerow([query, "-1", _clean_doc_id_link(doc_id), "", ""])
|
||||
|
||||
logger.debug("Appended score to csv file")
|
||||
|
||||
|
||||
def _append_score_to_csv(
|
||||
query: str,
|
||||
score: float,
|
||||
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
|
||||
) -> None:
|
||||
"""
|
||||
Append the score to the CSV file.
|
||||
"""
|
||||
|
||||
file_exists = os.path.isfile(csv_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
csv_dir = os.path.dirname(csv_path)
|
||||
if csv_dir and not os.path.exists(csv_dir):
|
||||
Path(csv_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
|
||||
writer = csv.writer(file)
|
||||
|
||||
# Write header if file is new
|
||||
if not file_exists:
|
||||
writer.writerow(["query", "position", "document_id", "answer", "score"])
|
||||
|
||||
# Write the ranking stats
|
||||
|
||||
writer.writerow([query, "", "", "", score])
|
||||
|
||||
logger.debug("Appended score to csv file")
|
||||
|
||||
|
||||
def _append_search_results_to_csv(
|
||||
query: str,
|
||||
doc_results: list[str],
|
||||
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
|
||||
) -> None:
|
||||
"""
|
||||
Append the search results to the CSV file.
|
||||
"""
|
||||
|
||||
file_exists = os.path.isfile(csv_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
csv_dir = os.path.dirname(csv_path)
|
||||
if csv_dir and not os.path.exists(csv_dir):
|
||||
Path(csv_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
|
||||
writer = csv.writer(file)
|
||||
|
||||
# Write header if file is new
|
||||
if not file_exists:
|
||||
writer.writerow(["query", "position", "document_id", "answer", "score"])
|
||||
|
||||
# Write the ranking stats
|
||||
|
||||
for pos, doc in enumerate(doc_results, start=1):
|
||||
writer.writerow([query, pos, _clean_doc_id_link(doc), "", ""])
|
||||
|
||||
logger.debug("Appended search results to csv file")
|
||||
|
||||
|
||||
def _append_answer_to_csv(
|
||||
query: str,
|
||||
answer: str,
|
||||
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
|
||||
) -> None:
|
||||
"""
|
||||
Append ranking statistics to a CSV file.
|
||||
|
||||
Args:
|
||||
ranking_stats: List of tuples containing (query, hit_position, document_id)
|
||||
csv_path: Path to the CSV file to append to
|
||||
"""
|
||||
|
||||
file_exists = os.path.isfile(csv_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
csv_dir = os.path.dirname(csv_path)
|
||||
if csv_dir and not os.path.exists(csv_dir):
|
||||
Path(csv_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
|
||||
writer = csv.writer(file)
|
||||
|
||||
# Write header if file is new
|
||||
if not file_exists:
|
||||
writer.writerow(["query", "position", "document_id", "answer", "score"])
|
||||
|
||||
# Write the ranking stats
|
||||
|
||||
writer.writerow([query, "", "", answer, ""])
|
||||
|
||||
logger.debug("Appended answer to csv file")
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -134,6 +329,8 @@ class Answer:
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
_HACKATHON_TEST_EXECUTION = False
|
||||
|
||||
if self._processed_stream is not None:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
@@ -154,22 +351,120 @@ class Answer:
|
||||
)
|
||||
):
|
||||
run_langgraph = run_dc_graph
|
||||
|
||||
elif (
|
||||
self.graph_config.inputs.persona
|
||||
and self.graph_config.inputs.persona.description.startswith(
|
||||
"Hackathon Test"
|
||||
)
|
||||
):
|
||||
_HACKATHON_TEST_EXECUTION = True
|
||||
run_langgraph = run_hackathon_graph
|
||||
|
||||
else:
|
||||
run_langgraph = run_basic_graph
|
||||
|
||||
stream = run_langgraph(
|
||||
self.graph_config,
|
||||
)
|
||||
if _HACKATHON_TEST_EXECUTION:
|
||||
|
||||
processed_stream = []
|
||||
for packet in stream:
|
||||
if self.is_cancelled():
|
||||
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
input_data = str(self.graph_config.inputs.prompt_builder.raw_user_query)
|
||||
|
||||
if input_data.startswith("["):
|
||||
input_type = "json"
|
||||
input_list = json.loads(input_data)
|
||||
else:
|
||||
input_type = "list"
|
||||
input_list = input_data.split(";")
|
||||
|
||||
num_examples_with_ground_truth = 0
|
||||
total_score = 0.0
|
||||
question = None
|
||||
|
||||
for question_num, question_data in enumerate(input_list):
|
||||
|
||||
ground_truth_docs = None
|
||||
if input_type == "json":
|
||||
question = question_data["question"]
|
||||
ground_truth = question_data.get("ground_truth")
|
||||
if ground_truth:
|
||||
ground_truth_docs = [x.get("doc_link") for x in ground_truth]
|
||||
logger.info(f"Question {question_num}: {question}")
|
||||
_append_ground_truth_to_csv(question, ground_truth_docs)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
question = question_data
|
||||
|
||||
self.graph_config.inputs.prompt_builder.raw_user_query = question
|
||||
self.graph_config.inputs.prompt_builder.user_message_and_token_cnt = (
|
||||
HumanMessage(
|
||||
content=question, additional_kwargs={}, response_metadata={}
|
||||
),
|
||||
2,
|
||||
)
|
||||
self.graph_config.tooling.force_use_tool.force_use = True
|
||||
|
||||
stream = run_langgraph(
|
||||
self.graph_config,
|
||||
)
|
||||
processed_stream = []
|
||||
for packet in stream:
|
||||
if self.is_cancelled():
|
||||
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield packet
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
yield packet
|
||||
|
||||
llm_answer_segments: list[str] = []
|
||||
doc_results: list[str] | None = None
|
||||
for answer_piece in processed_stream:
|
||||
if isinstance(answer_piece, OnyxAnswerPiece):
|
||||
llm_answer_segments.append(answer_piece.answer_piece or "")
|
||||
elif isinstance(answer_piece, ToolCallFinalResult):
|
||||
doc_results = [x.get("link") for x in answer_piece.tool_result]
|
||||
|
||||
if doc_results:
|
||||
_append_search_results_to_csv(question, doc_results)
|
||||
|
||||
_append_answer_to_csv(question, "".join(llm_answer_segments))
|
||||
|
||||
if ground_truth_docs and doc_results:
|
||||
num_examples_with_ground_truth += 1
|
||||
doc_score = 0.0
|
||||
for doc_id in ground_truth_docs:
|
||||
doc_score += _get_doc_score(doc_id, doc_results)
|
||||
|
||||
_append_score_to_csv(question, doc_score)
|
||||
total_score += doc_score
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
|
||||
if num_examples_with_ground_truth > 0:
|
||||
comprehensive_score = total_score / num_examples_with_ground_truth
|
||||
else:
|
||||
comprehensive_score = 0
|
||||
|
||||
_append_empty_line()
|
||||
|
||||
if not question:
|
||||
raise RuntimeError(f"No questions in input list; {input_list=}")
|
||||
_append_score_to_csv(question, comprehensive_score)
|
||||
|
||||
else:
|
||||
|
||||
stream = run_langgraph(
|
||||
self.graph_config,
|
||||
)
|
||||
|
||||
processed_stream = []
|
||||
for packet in stream:
|
||||
if self.is_cancelled():
|
||||
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield packet
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
yield packet
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
yield packet
|
||||
self._processed_stream = processed_stream
|
||||
self._processed_stream = processed_stream
|
||||
|
||||
@property
|
||||
def llm_answer(self) -> str:
|
||||
|
||||
@@ -1012,6 +1012,7 @@ def stream_chat_message_objects(
|
||||
tools=tools,
|
||||
db_session=db_session,
|
||||
use_agentic_search=new_msg_req.use_agentic_search,
|
||||
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
|
||||
|
||||
@@ -187,8 +187,12 @@ class AnswerPromptBuilder:
|
||||
|
||||
final_messages_with_tokens.append(self.user_message_and_token_cnt)
|
||||
|
||||
if self.new_messages_and_token_cnts:
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
|
||||
if (
|
||||
self.new_messages_and_token_cnts
|
||||
and isinstance(self.user_message_and_token_cnt[0].content, str)
|
||||
and self.user_message_and_token_cnt[0].content.startswith("Refer")
|
||||
):
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts[-2:])
|
||||
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
|
||||
@@ -787,3 +787,7 @@ S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
|
||||
# Forcing Vespa Language
|
||||
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
|
||||
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")
|
||||
|
||||
HACKATHON_OUTPUT_CSV_PATH = os.environ.get(
|
||||
"HACKATHON_OUTPUT_CSV_PATH", "/tmp/hackathon_output.csv"
|
||||
)
|
||||
|
||||
@@ -483,3 +483,24 @@ def section_relevance_list_impl(
|
||||
items=final_context_sections,
|
||||
)
|
||||
return [ind in llm_indices for ind in range(len(final_context_sections))]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from onyx.db.engine.sql_engine import SqlEngine, get_session_with_current_tenant
|
||||
from onyx.llm.factory import get_default_llms
|
||||
|
||||
SqlEngine.init_engine(pool_size=10, max_overflow=10)
|
||||
|
||||
llm, fast_llm = get_default_llms()
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
sp = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query="What is Onyx?",
|
||||
),
|
||||
user=None,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
skip_query_analysis=True,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import string
|
||||
from collections.abc import Callable
|
||||
from enum import Enum
|
||||
|
||||
import nltk # type:ignore
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -20,6 +21,7 @@ from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA_KEYWORD
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_multilingual_expansion
|
||||
from onyx.db.search_settings import get_simple_expansion_settings
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.document_index.vespa.shared_utils.utils import (
|
||||
@@ -27,6 +29,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
|
||||
)
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
|
||||
from onyx.secondary_llm_flows.query_expansion import simple_unilingual_query_expansion
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
@@ -41,6 +44,33 @@ from shared_configs.model_server_models import Embedding
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class ExpansionType(Enum):
|
||||
Off = "off"
|
||||
Simple = "simple"
|
||||
Multilingual = "multilingual"
|
||||
|
||||
|
||||
def _get_expansion_type(
|
||||
db_session: Session,
|
||||
query: str,
|
||||
) -> tuple[ExpansionType, list[str], int]:
|
||||
if "\n" in query or "\r" in query:
|
||||
return ExpansionType.Off, [], 0
|
||||
|
||||
multilingual_expansion = get_multilingual_expansion(db_session)
|
||||
if multilingual_expansion:
|
||||
return ExpansionType.Multilingual, multilingual_expansion, 0
|
||||
|
||||
enable_simple_expansion, simple_expansion_count = get_simple_expansion_settings(
|
||||
db_session
|
||||
)
|
||||
|
||||
if enable_simple_expansion:
|
||||
return ExpansionType.Simple, [], simple_expansion_count
|
||||
|
||||
return ExpansionType.Off, [], 0
|
||||
|
||||
|
||||
def _dedupe_chunks(
|
||||
chunks: list[InferenceChunkUncleaned],
|
||||
) -> list[InferenceChunkUncleaned]:
|
||||
@@ -348,6 +378,92 @@ def _simplify_text(text: str) -> str:
|
||||
).lower()
|
||||
|
||||
|
||||
def _retrieve_search_with_simple_expansion(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
query: SearchQuery,
|
||||
simple_expansion_count: int,
|
||||
) -> list[InferenceChunk]:
|
||||
# Use simple unilingual expansion for basic cases
|
||||
simplified_queries_simple = set()
|
||||
run_queries_simple: list[tuple[Callable, tuple]] = []
|
||||
|
||||
query_rephrases_simple = simple_unilingual_query_expansion(
|
||||
query=query.query,
|
||||
num_expansions=simple_expansion_count,
|
||||
)
|
||||
# Just to be extra sure, add the original query.
|
||||
query_rephrases_simple.append(query.query)
|
||||
for rephrase in set(query_rephrases_simple):
|
||||
# Sometimes the model rephrases the query in the same language with minor changes
|
||||
# Avoid doing an extra search with the minor changes as this biases the results
|
||||
simplified_rephrase = _simplify_text(rephrase)
|
||||
if simplified_rephrase in simplified_queries_simple:
|
||||
continue
|
||||
simplified_queries_simple.add(simplified_rephrase)
|
||||
|
||||
q_copy = query.model_copy(
|
||||
update={
|
||||
"query": rephrase,
|
||||
# need to recompute for each rephrase
|
||||
# note that `SearchQuery` is a frozen model, so we can't update
|
||||
# it below
|
||||
"precomputed_query_embedding": None,
|
||||
},
|
||||
deep=True,
|
||||
)
|
||||
run_queries_simple.append(
|
||||
(
|
||||
doc_index_retrieval,
|
||||
(q_copy, document_index, db_session),
|
||||
)
|
||||
)
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries_simple)
|
||||
return combine_retrieval_results(parallel_search_results)
|
||||
|
||||
|
||||
def _retrieve_search_with_multilingual_expansion(
|
||||
db_session: Session,
|
||||
document_index: DocumentIndex,
|
||||
query: SearchQuery,
|
||||
multilingual_expansion: list[str],
|
||||
) -> list[InferenceChunk]:
|
||||
# Use multilingual expansion if configured
|
||||
simplified_queries = set()
|
||||
run_queries: list[tuple[Callable, tuple]] = []
|
||||
|
||||
# Currently only uses query expansion on multilingual use cases
|
||||
query_rephrases = multilingual_query_expansion(query.query, multilingual_expansion)
|
||||
# Just to be extra sure, add the original query.
|
||||
query_rephrases.append(query.query)
|
||||
for rephrase in set(query_rephrases):
|
||||
# Sometimes the model rephrases the query in the same language with minor changes
|
||||
# Avoid doing an extra search with the minor changes as this biases the results
|
||||
simplified_rephrase = _simplify_text(rephrase)
|
||||
if simplified_rephrase in simplified_queries:
|
||||
continue
|
||||
simplified_queries.add(simplified_rephrase)
|
||||
|
||||
q_copy = query.model_copy(
|
||||
update={
|
||||
"query": rephrase,
|
||||
# need to recompute for each rephrase
|
||||
# note that `SearchQuery` is a frozen model, so we can't update
|
||||
# it below
|
||||
"precomputed_query_embedding": None,
|
||||
},
|
||||
deep=True,
|
||||
)
|
||||
run_queries.append(
|
||||
(
|
||||
doc_index_retrieval,
|
||||
(q_copy, document_index, db_session),
|
||||
)
|
||||
)
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
return combine_retrieval_results(parallel_search_results)
|
||||
|
||||
|
||||
def retrieve_chunks(
|
||||
query: SearchQuery,
|
||||
document_index: DocumentIndex,
|
||||
@@ -358,48 +474,33 @@ def retrieve_chunks(
|
||||
) -> list[InferenceChunk]:
|
||||
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
|
||||
|
||||
multilingual_expansion = get_multilingual_expansion(db_session)
|
||||
# Don't do query expansion on complex queries, rephrasings likely would not work well
|
||||
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:
|
||||
expansion_type, multilingual_expansion, simple_expansion_count = (
|
||||
_get_expansion_type(
|
||||
db_session=db_session,
|
||||
query=query.query,
|
||||
)
|
||||
)
|
||||
|
||||
if expansion_type == ExpansionType.Off:
|
||||
top_chunks = doc_index_retrieval(
|
||||
query=query, document_index=document_index, db_session=db_session
|
||||
)
|
||||
else:
|
||||
simplified_queries = set()
|
||||
run_queries: list[tuple[Callable, tuple]] = []
|
||||
|
||||
# Currently only uses query expansion on multilingual use cases
|
||||
query_rephrases = multilingual_query_expansion(
|
||||
query.query, multilingual_expansion
|
||||
elif expansion_type == ExpansionType.Simple:
|
||||
top_chunks = _retrieve_search_with_simple_expansion(
|
||||
db_session=db_session,
|
||||
document_index=document_index,
|
||||
query=query,
|
||||
simple_expansion_count=simple_expansion_count,
|
||||
)
|
||||
# Just to be extra sure, add the original query.
|
||||
query_rephrases.append(query.query)
|
||||
for rephrase in set(query_rephrases):
|
||||
# Sometimes the model rephrases the query in the same language with minor changes
|
||||
# Avoid doing an extra search with the minor changes as this biases the results
|
||||
simplified_rephrase = _simplify_text(rephrase)
|
||||
if simplified_rephrase in simplified_queries:
|
||||
continue
|
||||
simplified_queries.add(simplified_rephrase)
|
||||
|
||||
q_copy = query.model_copy(
|
||||
update={
|
||||
"query": rephrase,
|
||||
# need to recompute for each rephrase
|
||||
# note that `SearchQuery` is a frozen model, so we can't update
|
||||
# it below
|
||||
"precomputed_query_embedding": None,
|
||||
},
|
||||
deep=True,
|
||||
)
|
||||
run_queries.append(
|
||||
(
|
||||
doc_index_retrieval,
|
||||
(q_copy, document_index, db_session),
|
||||
)
|
||||
)
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
top_chunks = combine_retrieval_results(parallel_search_results)
|
||||
elif expansion_type == ExpansionType.Multilingual:
|
||||
top_chunks = _retrieve_search_with_multilingual_expansion(
|
||||
db_session=db_session,
|
||||
document_index=document_index,
|
||||
query=query,
|
||||
multilingual_expansion=multilingual_expansion,
|
||||
)
|
||||
else:
|
||||
raise RuntimeError
|
||||
|
||||
if not top_chunks:
|
||||
logger.warning(
|
||||
|
||||
@@ -56,6 +56,8 @@ def create_search_settings(
|
||||
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
|
||||
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
|
||||
multilingual_expansion=search_settings.multilingual_expansion,
|
||||
enable_simple_query_expansion=search_settings.enable_simple_query_expansion,
|
||||
simple_expansion_count=search_settings.simple_expansion_count,
|
||||
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
|
||||
rerank_model_name=search_settings.rerank_model_name,
|
||||
rerank_provider_type=search_settings.rerank_provider_type,
|
||||
@@ -195,6 +197,26 @@ def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
|
||||
return search_settings.multilingual_expansion
|
||||
|
||||
|
||||
def get_simple_expansion_settings(
|
||||
db_session: Session | None = None,
|
||||
) -> tuple[bool, int]:
|
||||
"""Get simple query expansion settings from the database.
|
||||
|
||||
Returns:
|
||||
tuple: (enabled, expansion_count)
|
||||
"""
|
||||
# if db_session is None:
|
||||
# with get_session_with_current_tenant() as db_session:
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
# else:
|
||||
# search_settings = get_current_search_settings(db_session)
|
||||
# if not search_settings:
|
||||
# return False, 2
|
||||
# return search_settings.enable_simple_query_expansion, search_settings.simple_expansion_count
|
||||
|
||||
return True, 3
|
||||
|
||||
|
||||
def update_search_settings(
|
||||
current_settings: SearchSettings,
|
||||
updated_settings: SavedSearchSettings,
|
||||
|
||||
@@ -9,6 +9,25 @@ Query:
|
||||
{query}
|
||||
""".strip()
|
||||
|
||||
SIMPLE_QUERY_EXPANSION_PROMPT = """
|
||||
Generate {num_expansions} alternative phrasings of the following query in the same language.
|
||||
The goal is to express the same intent in different ways that might retrieve different relevant documents.
|
||||
|
||||
Guidelines:
|
||||
- Keep the same meaning and intent as the original query
|
||||
- Use different vocabulary and sentence structures
|
||||
- Make each expansion distinct from the others
|
||||
- Focus on different aspects or angles of the query
|
||||
- Use synonyms, different verb forms, or alternative phrasings
|
||||
|
||||
Original query:
|
||||
{query}
|
||||
|
||||
Respond with EXACTLY {num_expansions} alternative queries, one per line. Do not include any explanations or numbering.
|
||||
|
||||
Alternative queries:
|
||||
""".strip()
|
||||
|
||||
SLACK_LANGUAGE_REPHRASE_PROMPT = """
|
||||
As an AI assistant employed by an organization, \
|
||||
your role is to transform user messages into concise \
|
||||
@@ -27,3 +46,4 @@ Query:
|
||||
# Use the following for easy viewing of prompts
|
||||
if __name__ == "__main__":
|
||||
print(LANGUAGE_REPHRASE_PROMPT)
|
||||
print(SIMPLE_QUERY_EXPANSION_PROMPT)
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.prompts.chat_prompts import HISTORY_QUERY_REPHRASE
|
||||
from onyx.prompts.miscellaneous_prompts import LANGUAGE_REPHRASE_PROMPT
|
||||
from onyx.prompts.miscellaneous_prompts import SIMPLE_QUERY_EXPANSION_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.text_processing import count_punctuation
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -164,3 +165,47 @@ def thread_based_query_rephrase(
|
||||
logger.debug(f"Rephrased combined query: {rephrased_query}")
|
||||
|
||||
return rephrased_query
|
||||
|
||||
|
||||
def simple_unilingual_query_expansion(
|
||||
query: str,
|
||||
num_expansions: int = 2,
|
||||
use_threads: bool = True,
|
||||
) -> list[str]:
|
||||
"""
|
||||
Generate alternative phrasings of a query in the same language.
|
||||
This is useful for improving search recall by trying different ways to express the same intent.
|
||||
"""
|
||||
|
||||
def _get_simple_expansion_messages() -> list[dict[str, str]]:
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": SIMPLE_QUERY_EXPANSION_PROMPT.format(
|
||||
query=query, num_expansions=num_expansions
|
||||
),
|
||||
},
|
||||
]
|
||||
return messages
|
||||
|
||||
try:
|
||||
_, fast_llm = get_default_llms(timeout=5)
|
||||
except GenAIDisabledException:
|
||||
logger.warning("Unable to perform simple query expansion, Gen AI disabled")
|
||||
return [query]
|
||||
|
||||
messages = _get_simple_expansion_messages()
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = message_to_string(fast_llm.invoke(filled_llm_prompt))
|
||||
logger.debug(f"Simple query expansion output: {model_output}")
|
||||
|
||||
# Parse the response - expect newline-separated queries
|
||||
expanded_queries = [
|
||||
line.strip() for line in model_output.split("\n") if line.strip()
|
||||
]
|
||||
|
||||
# Add the original query to ensure we always have at least one query
|
||||
if query not in expanded_queries:
|
||||
expanded_queries.append(query)
|
||||
|
||||
return expanded_queries[: num_expansions + 1] # +1 to account for original query
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import copy
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
@@ -19,6 +22,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
|
||||
from onyx.chat.prune_and_merge import prune_and_merge_sections
|
||||
from onyx.chat.prune_and_merge import prune_sections
|
||||
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
@@ -62,6 +66,39 @@ SEARCH_EVALUATION_ID = "llm_doc_eval"
|
||||
QUERY_FIELD = "query"
|
||||
|
||||
|
||||
def _append_ranking_stats_to_csv(
|
||||
llm_doc_results: list[tuple[int, str]],
|
||||
query: str,
|
||||
csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
|
||||
) -> None:
|
||||
"""
|
||||
Append ranking statistics to a CSV file.
|
||||
|
||||
Args:
|
||||
ranking_stats: List of tuples containing (query, hit_position, document_id)
|
||||
csv_path: Path to the CSV file to append to
|
||||
"""
|
||||
file_exists = os.path.isfile(csv_path)
|
||||
|
||||
# Create directory if it doesn't exist
|
||||
csv_dir = os.path.dirname(csv_path)
|
||||
if csv_dir and not os.path.exists(csv_dir):
|
||||
Path(csv_dir).mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
|
||||
writer = csv.writer(file)
|
||||
|
||||
# Write header if file is new
|
||||
if not file_exists:
|
||||
writer.writerow(["query", "position", "document_id", "answer", "score"])
|
||||
|
||||
# Write the ranking stats
|
||||
for pos, doc in llm_doc_results:
|
||||
writer.writerow([query, pos, doc, ""])
|
||||
|
||||
logger.debug(f"Appended {len(llm_doc_results)} ranking stats to {csv_path}")
|
||||
|
||||
|
||||
class SearchResponseSummary(SearchQueryInfo):
|
||||
top_sections: list[InferenceSection]
|
||||
rephrased_query: str | None = None
|
||||
@@ -502,6 +539,12 @@ def yield_search_responses(
|
||||
)
|
||||
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
|
||||
|
||||
# Append ranking statistics to a CSV file
|
||||
llm_doc_results = []
|
||||
for pos, doc in enumerate(llm_docs):
|
||||
llm_doc_results.append((pos, doc.document_id))
|
||||
# _append_ranking_stats_to_csv(llm_doc_results, query)
|
||||
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||
|
||||
|
||||
|
||||
@@ -4,16 +4,12 @@ This Python script evaluates the search results for a list of queries.
|
||||
|
||||
This script will likely get refactored in the future as an API endpoint.
|
||||
In the meanwhile, it is used to evaluate the search quality using locally ingested documents.
|
||||
The key differentiating factor with `answer_quality` is that it can evaluate results without explicit "ground truth" using the reranker as a reference.
|
||||
|
||||
## Usage
|
||||
|
||||
1. Ensure you have the required dependencies installed and onyx running.
|
||||
1. Ensure you have the required dependencies installed and onyx running. Note that auth must be disabled for this script to work (`AUTH_TYPE=disabled`, which is the case by default).
|
||||
|
||||
2. Ensure a reranker model is configured in the search settings.
|
||||
This can be checked/modified by opening the admin panel, going to search settings, and ensuring a reranking model is set.
|
||||
|
||||
3. Set up the PYTHONPATH permanently:
|
||||
2. Set up the PYTHONPATH permanently:
|
||||
Add the following line to your shell configuration file (e.g., `~/.bashrc`, `~/.zshrc`, or `~/.bash_profile`):
|
||||
```
|
||||
export PYTHONPATH=$PYTHONPATH:/path/to/onyx/backend
|
||||
@@ -21,42 +17,41 @@ This can be checked/modified by opening the admin panel, going to search setting
|
||||
Replace `/path/to/onyx` with the actual path to your Onyx repository.
|
||||
After adding this line, restart your terminal or run `source ~/.bashrc` (or the appropriate config file) to apply the changes.
|
||||
|
||||
4. Navigate to Onyx repo, search_quality folder:
|
||||
3. Navigate to Onyx repo, **search_quality** folder:
|
||||
|
||||
```
|
||||
cd path/to/onyx/backend/tests/regression/search_quality
|
||||
```
|
||||
|
||||
5. Copy `test_queries.json.template` to `test_queries.json` and add/remove test queries in it. The possible fields are:
|
||||
4. Copy `test_queries.json.template` to `test_queries.json` and add/remove test queries in it. The possible fields are:
|
||||
|
||||
- `question: str` the query
|
||||
- `question_search: Optional[str]` modified query specifically for the search step
|
||||
- `ground_truth: Optional[list[GroundTruth]]` a ranked list of expected search results with fields:
|
||||
- `doc_source: str` document source (e.g., Web, Drive, Linear), currently unused
|
||||
- `doc_source: str` document source (e.g., web, google_drive, linear), used to normalize links in some cases
|
||||
- `doc_link: str` link associated with document, used to find corresponding document in local index
|
||||
- `categories: Optional[list[str]]` list of categories, used to aggregate evaluation results
|
||||
|
||||
6. Copy `search_eval_config.yaml.template` to `search_eval_config.yaml` and specify the search and eval parameters
|
||||
|
||||
7. Run `run_search_eval.py` to run the search and evaluate the search results
|
||||
5. Run `run_search_eval.py` to evaluate the queries. All parameters are optional and have sensible defaults:
|
||||
|
||||
```
|
||||
python run_search_eval.py
|
||||
-d --dataset # Path to the test-set JSON file (default: ./test_queries.json)
|
||||
-n --num_search # Maximum number of search results to check per query (default: 50)
|
||||
-a --num_answer # Maximum number of search results to use for answer evaluation (default: 25)
|
||||
-w --workers # Number of parallel search requests (default: 10)
|
||||
-q --timeout # Request timeout in seconds (default: 120)
|
||||
-e --api_endpoint # Base URL of the Onyx API server (default: http://127.0.0.1:8080)
|
||||
-s --search_only # Only perform search and not answer evaluation (default: false)
|
||||
-r --rerank_all # Always rerank all search results (default: false)
|
||||
-t --tenant_id # Tenant ID to use for the evaluation (default: None)
|
||||
```
|
||||
|
||||
8. Optionally, save the generated `test_queries.json` in the export folder to reuse the generated `question_search`, and rerun the search evaluation with alternative search parameters.
|
||||
6. After the run an `eval-YYYY-MM-DD-HH-MM-SS` folder is created containing:
|
||||
|
||||
## Metrics
|
||||
There are two main metrics currently implemented:
|
||||
- ratio_topk: the ratio of documents in the comparison set that are in the topk search results (higher is better, 0-1)
|
||||
- avg_rank_delta: the average rank difference between the comparison set and search results (lower is better, 0-inf)
|
||||
* `test_queries.json` – the dataset used with the indexed ground truth documents.
|
||||
* `search_results.json` – per-query details.
|
||||
* `results_by_category.csv` – aggregated metrics per category and for "all".
|
||||
* `search_position_chart.png` – bar-chart of ground-truth ranks.
|
||||
|
||||
Ratio topk gives a general idea on whether the most relevant documents are appearing first in the search results. Decreasing `eval_topk` will make this metric stricter, requiring relevant documents to appear in a narrow window.
|
||||
|
||||
Avg rank delta is another metric which can give insight on the performance of documents not in the topk search results. If none of the comparison documents are in the topk, `ratio_topk` will only show a 0, whereas `avg_rank_delta` will show a higher value the worse the search results gets.
|
||||
|
||||
Furthermore, there are two versions of the metrics: ground truth, and soft truth.
|
||||
|
||||
The ground truth includes documents explicitly listed as relevant in the test dataset. The ground truth metrics will only be computed if a ground truth set is provided for the question and exists in the index.
|
||||
|
||||
The soft truth is built on top of the ground truth (if provided), filling the remaining entries with results from the reranker. The soft truth metrics will only be computed if `skip_rerank` is false. Computing the soft truth metric can be extremely slow, especially for large `num_returned_hits`. However, it can provide a good basis when there are many relevant documents in no particular order, or for running quick tests without explicitly having to mention which documents are relevant.
|
||||
You can copy the generated `test_queries.json` back to the root folder to skip the ground truth documents search step.
|
||||
75
backend/tests/regression/search_quality/models.py
Normal file
75
backend/tests/regression/search_quality/models.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
|
||||
|
||||
class GroundTruth(BaseModel):
|
||||
doc_source: DocumentSource
|
||||
doc_link: str
|
||||
|
||||
|
||||
class TestQuery(BaseModel):
|
||||
question: str
|
||||
ground_truth: list[GroundTruth] = []
|
||||
categories: list[str] = []
|
||||
|
||||
# autogenerated
|
||||
ground_truth_docids: list[str] = []
|
||||
|
||||
|
||||
class EvalConfig(BaseModel):
|
||||
max_search_results: int
|
||||
max_answer_context: int
|
||||
num_workers: int
|
||||
request_timeout: int
|
||||
api_url: str
|
||||
search_only: bool
|
||||
rerank_all: bool
|
||||
|
||||
|
||||
class OneshotQAResult(BaseModel):
|
||||
time_taken: float
|
||||
top_documents: list[SavedSearchDoc]
|
||||
answer: str | None
|
||||
|
||||
|
||||
class RetrievedDocument(BaseModel):
|
||||
document_id: str
|
||||
content: str
|
||||
|
||||
|
||||
class AnalysisSummary(BaseModel):
|
||||
question: str
|
||||
categories: list[str]
|
||||
found: bool
|
||||
rank: int | None
|
||||
total_results: int
|
||||
ground_truth_count: int
|
||||
answer: str | None = None
|
||||
response_relevancy: float | None = None
|
||||
response_groundedness: float | None = None
|
||||
faithfulness: float | None = None
|
||||
retrieved: list[RetrievedDocument] = []
|
||||
time_taken: float | None = None
|
||||
|
||||
|
||||
class SearchMetrics(BaseModel):
|
||||
total_queries: int
|
||||
found_count: int
|
||||
|
||||
# for found results
|
||||
best_rank: int
|
||||
worst_rank: int
|
||||
average_rank: float
|
||||
top_k_accuracy: dict[int, float]
|
||||
|
||||
|
||||
class AnswerMetrics(BaseModel):
|
||||
average_response_relevancy: float
|
||||
average_response_groundedness: float
|
||||
average_faithfulness: float
|
||||
|
||||
|
||||
class CombinedMetrics(SearchMetrics, AnswerMetrics):
|
||||
average_time_taken: float
|
||||
@@ -1,151 +1,692 @@
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import requests
|
||||
from pydantic import ValidationError
|
||||
from requests.exceptions import RequestException
|
||||
from retry import retry
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import OneShotQARequest
|
||||
from ee.onyx.server.query_and_chat.models import OneShotQAResponse
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
|
||||
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.search_settings import get_multilingual_expansion
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from tests.regression.search_quality.util_config import load_config
|
||||
from tests.regression.search_quality.util_data import export_test_queries
|
||||
from tests.regression.search_quality.util_data import load_test_queries
|
||||
from tests.regression.search_quality.util_eval import evaluate_one_query
|
||||
from tests.regression.search_quality.util_eval import get_corresponding_document
|
||||
from tests.regression.search_quality.util_eval import metric_names
|
||||
from tests.regression.search_quality.util_retrieve import rerank_one_query
|
||||
from tests.regression.search_quality.util_retrieve import search_one_query
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
|
||||
from tests.regression.search_quality.models import AnalysisSummary
|
||||
from tests.regression.search_quality.models import CombinedMetrics
|
||||
from tests.regression.search_quality.models import EvalConfig
|
||||
from tests.regression.search_quality.models import OneshotQAResult
|
||||
from tests.regression.search_quality.models import TestQuery
|
||||
from tests.regression.search_quality.utils import find_document
|
||||
from tests.regression.search_quality.utils import ragas_evaluate
|
||||
from tests.regression.search_quality.utils import search_docs_to_doc_contexts
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
GENERAL_HEADERS = {"Content-Type": "application/json"}
|
||||
TOP_K_LIST = [1, 3, 5, 10]
|
||||
|
||||
def run_search_eval() -> None:
|
||||
config = load_config()
|
||||
test_queries = load_test_queries()
|
||||
|
||||
# export related
|
||||
export_path = Path(config.export_folder)
|
||||
export_test_queries(test_queries, export_path / "test_queries.json")
|
||||
search_result_path = export_path / "search_results.csv"
|
||||
eval_path = export_path / "eval_results.csv"
|
||||
aggregate_eval_path = export_path / "aggregate_eval.csv"
|
||||
aggregate_results: dict[str, list[list[float]]] = defaultdict(
|
||||
lambda: [[] for _ in metric_names]
|
||||
)
|
||||
class SearchAnswerAnalyzer:
|
||||
def __init__(
|
||||
self,
|
||||
config: EvalConfig,
|
||||
tenant_id: str | None = None,
|
||||
):
|
||||
if not MULTI_TENANT:
|
||||
logger.info("Running in single-tenant mode")
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
|
||||
elif tenant_id is None:
|
||||
raise ValueError("Tenant ID is required for multi-tenant")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
multilingual_expansion = get_multilingual_expansion(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
document_index = get_default_document_index(search_settings, None)
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
# rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
self.config = config
|
||||
self.tenant_id = tenant_id
|
||||
self.results: list[AnalysisSummary] = []
|
||||
self.stats: dict[str, list[AnalysisSummary]] = defaultdict(list)
|
||||
self.metrics: dict[str, CombinedMetrics] = {}
|
||||
|
||||
if config.skip_rerank:
|
||||
logger.warning("Reranking is disabled, evaluation will not run")
|
||||
elif rerank_settings.rerank_model_name is None:
|
||||
raise ValueError(
|
||||
"Reranking is enabled but no reranker is configured. "
|
||||
"Please set the reranker in the admin panel search settings."
|
||||
)
|
||||
# get search related settings
|
||||
self._rerank_settings = self._get_rerank_settings()
|
||||
|
||||
# run search and evaluate
|
||||
logger.info(
|
||||
"Running search and evaluation... "
|
||||
f"Individual search and evaluation results will be saved to {search_result_path} and {eval_path}"
|
||||
def run_analysis(self, dataset_path: Path, export_path: Path) -> None:
|
||||
# load and save the dataset
|
||||
dataset = self._load_dataset(dataset_path)
|
||||
dataset_size = len(dataset)
|
||||
|
||||
# export the processed dataset
|
||||
dataset_export_path = export_path / "test_queries.json"
|
||||
with dataset_export_path.open("w") as f:
|
||||
dataset_serializable = [q.model_dump(mode="json") for q in dataset]
|
||||
json.dump(dataset_serializable, f, indent=4)
|
||||
|
||||
# run the analysis
|
||||
logger.info("Starting analysis of %d queries...", dataset_size)
|
||||
logger.info("Using %d parallel workers", self.config.num_workers)
|
||||
|
||||
indexed_test_cases = [(i, test_case) for i, test_case in enumerate(dataset)]
|
||||
indexed_results: dict[int, AnalysisSummary] = {}
|
||||
with ThreadPoolExecutor(max_workers=self.config.num_workers) as executor:
|
||||
future_to_index = {
|
||||
executor.submit(
|
||||
self._run_and_analyze_one_wrapper, test_case_with_index
|
||||
): test_case_with_index[0]
|
||||
for test_case_with_index in indexed_test_cases
|
||||
}
|
||||
|
||||
# process completed tasks as they finish
|
||||
for completed_count, future in enumerate(as_completed(future_to_index), 1):
|
||||
try:
|
||||
index, result = future.result()
|
||||
indexed_results[index] = result
|
||||
|
||||
# update category stats on the fly
|
||||
self.stats["all"].append(result)
|
||||
for cat in result.categories or ["uncategorized"]:
|
||||
self.stats[cat].append(result)
|
||||
except Exception as e:
|
||||
print(f"[{completed_count}/{dataset_size}] ✗ Error: {e}")
|
||||
continue
|
||||
|
||||
# print progress with query info
|
||||
question = (
|
||||
result.question[:50] + "..."
|
||||
if len(result.question) > 50
|
||||
else result.question
|
||||
)
|
||||
status = "✓ Found" if result.found else "✗ Not found"
|
||||
rank_info = f" (rank {result.rank})" if result.found else ""
|
||||
print(
|
||||
f"[{completed_count}/{dataset_size}] {status}{rank_info}: {question}"
|
||||
)
|
||||
|
||||
# sort results by original order and build the metrics
|
||||
self.results = [indexed_results[i] for i in sorted(indexed_results.keys())]
|
||||
self._build_metrics()
|
||||
|
||||
def generate_summary(self) -> None:
|
||||
logger.info("Generating summary...")
|
||||
|
||||
metrics_all = self.metrics.get("all")
|
||||
if metrics_all is None:
|
||||
logger.warning("Nothing to summarize")
|
||||
return
|
||||
|
||||
total_queries = metrics_all.total_queries
|
||||
found_count = metrics_all.found_count
|
||||
|
||||
print(
|
||||
f"Total test queries: {total_queries}\n"
|
||||
f"Ground truth found: {found_count} "
|
||||
f"({found_count / total_queries * 100:.1f}%)\n"
|
||||
f"Ground truth not found: {total_queries - found_count} "
|
||||
f"({(total_queries - found_count) / total_queries * 100:.1f}%)"
|
||||
)
|
||||
with (
|
||||
search_result_path.open("w") as search_file,
|
||||
eval_path.open("w") as eval_file,
|
||||
):
|
||||
search_csv_writer = csv.writer(search_file)
|
||||
eval_csv_writer = csv.writer(eval_file)
|
||||
search_csv_writer.writerow(
|
||||
["source", "query", "rank", "score", "doc_id", "chunk_id"]
|
||||
|
||||
if metrics_all.found_count > 0:
|
||||
print(
|
||||
"\nRank statistics (for found results):\n"
|
||||
f" Average rank: {metrics_all.average_rank:.2f}\n"
|
||||
f" Best rank: {metrics_all.best_rank}\n"
|
||||
f" Worst rank: {metrics_all.worst_rank}\n"
|
||||
)
|
||||
eval_csv_writer.writerow(["query", *metric_names])
|
||||
for k, acc in metrics_all.top_k_accuracy.items():
|
||||
print(f" Top-{k} accuracy: {acc:.1f}%")
|
||||
print(f"Average time taken: {metrics_all.average_time_taken:.2f}s")
|
||||
|
||||
for query in test_queries:
|
||||
# search and write results
|
||||
assert query.question_search is not None
|
||||
search_chunks = search_one_query(
|
||||
query.question_search,
|
||||
multilingual_expansion,
|
||||
document_index,
|
||||
db_session,
|
||||
config,
|
||||
)
|
||||
for rank, result in enumerate(search_chunks):
|
||||
search_csv_writer.writerow(
|
||||
if not self.config.search_only:
|
||||
print(
|
||||
"\nAnswer evaluation metrics:\n"
|
||||
f" Average response relevancy: {metrics_all.average_response_relevancy:.2f}\n"
|
||||
f" Average response groundedness: {metrics_all.average_response_groundedness:.2f}\n"
|
||||
f" Average faithfulness: {metrics_all.average_faithfulness:.2f}"
|
||||
)
|
||||
|
||||
def generate_detailed_report(self, export_path: Path) -> None:
|
||||
logger.info("Generating detailed report...")
|
||||
|
||||
# save results for future inspection
|
||||
results_json_path = export_path / "search_results.json"
|
||||
with results_json_path.open("w") as f:
|
||||
json.dump([r.model_dump(mode="json") for r in self.results], f, indent=4)
|
||||
logger.info("Saved search results to %s", results_json_path)
|
||||
|
||||
# save results by category
|
||||
csv_path = export_path / "results_by_category.csv"
|
||||
with csv_path.open("w", newline="") as csv_file:
|
||||
csv_writer = csv.writer(csv_file)
|
||||
csv_writer.writerow(
|
||||
[
|
||||
"category",
|
||||
"total_queries",
|
||||
"found",
|
||||
"percent_found",
|
||||
"avg_rank_found",
|
||||
*(
|
||||
[
|
||||
"search",
|
||||
query.question_search,
|
||||
rank,
|
||||
result.score,
|
||||
result.document_id,
|
||||
result.chunk_id,
|
||||
"avg_response_relevancy",
|
||||
"avg_response_groundedness",
|
||||
"avg_faithfulness",
|
||||
]
|
||||
)
|
||||
|
||||
rerank_chunks = []
|
||||
if not config.skip_rerank:
|
||||
# rerank and write results
|
||||
rerank_chunks = rerank_one_query(
|
||||
query.question, search_chunks, rerank_settings
|
||||
)
|
||||
for rank, result in enumerate(rerank_chunks):
|
||||
search_csv_writer.writerow(
|
||||
[
|
||||
"rerank",
|
||||
query.question,
|
||||
rank,
|
||||
result.score,
|
||||
result.document_id,
|
||||
result.chunk_id,
|
||||
]
|
||||
)
|
||||
|
||||
# evaluate and write results
|
||||
truth_documents = [
|
||||
doc
|
||||
for truth in query.ground_truth
|
||||
if (doc := get_corresponding_document(truth.doc_link, db_session))
|
||||
if not self.config.search_only
|
||||
else []
|
||||
),
|
||||
"avg_time_taken_sec",
|
||||
]
|
||||
metrics = evaluate_one_query(
|
||||
search_chunks, rerank_chunks, truth_documents, config.eval_topk
|
||||
)
|
||||
|
||||
for category, results in sorted(self.stats.items()):
|
||||
if not results:
|
||||
continue
|
||||
|
||||
metrics = self.metrics[category]
|
||||
found_count = metrics.found_count
|
||||
total_count = metrics.total_queries
|
||||
accuracy = found_count / total_count * 100 if total_count > 0 else 0
|
||||
|
||||
print(
|
||||
f"\n{category.upper()}:"
|
||||
f" total queries: {total_count}\n"
|
||||
f" found: {found_count} ({accuracy:.1f}%)"
|
||||
)
|
||||
metric_vals = [getattr(metrics, field) for field in metric_names]
|
||||
eval_csv_writer.writerow([query.question, *metric_vals])
|
||||
avg_rank = metrics.average_rank if metrics.found_count > 0 else None
|
||||
if avg_rank is not None:
|
||||
print(f" average rank (when found): {avg_rank:.2f}")
|
||||
if not self.config.search_only:
|
||||
print(
|
||||
f" average response relevancy: {metrics.average_response_relevancy:.2f}\n"
|
||||
f" average response groundedness: {metrics.average_response_groundedness:.2f}\n"
|
||||
f" average faithfulness: {metrics.average_faithfulness:.2f}"
|
||||
)
|
||||
print(f" average time taken: {metrics.average_time_taken:.2f}s")
|
||||
|
||||
# add to aggregation
|
||||
for category in ["all"] + query.categories:
|
||||
for i, val in enumerate(metric_vals):
|
||||
if val is not None:
|
||||
aggregate_results[category][i].append(val)
|
||||
|
||||
# aggregate and write results
|
||||
with aggregate_eval_path.open("w") as file:
|
||||
aggregate_csv_writer = csv.writer(file)
|
||||
aggregate_csv_writer.writerow(["category", *metric_names])
|
||||
|
||||
for category, agg_metrics in aggregate_results.items():
|
||||
aggregate_csv_writer.writerow(
|
||||
csv_writer.writerow(
|
||||
[
|
||||
category,
|
||||
total_count,
|
||||
found_count,
|
||||
f"{accuracy:.1f}",
|
||||
f"{avg_rank:.2f}" if avg_rank is not None else "",
|
||||
*(
|
||||
sum(metric) / len(metric) if metric else None
|
||||
for metric in agg_metrics
|
||||
[
|
||||
f"{metrics.average_response_relevancy:.2f}",
|
||||
f"{metrics.average_response_groundedness:.2f}",
|
||||
f"{metrics.average_faithfulness:.2f}",
|
||||
]
|
||||
if not self.config.search_only
|
||||
else []
|
||||
),
|
||||
f"{metrics.average_time_taken:.2f}",
|
||||
]
|
||||
)
|
||||
logger.info("Saved category breakdown csv to %s", csv_path)
|
||||
|
||||
def generate_chart(self, export_path: Path) -> None:
|
||||
logger.info("Generating search position chart...")
|
||||
|
||||
found_results = [r for r in self.results if r.found]
|
||||
not_found_count = len([r for r in self.results if not r.found])
|
||||
|
||||
if not found_results and not_found_count == 0:
|
||||
logger.warning("No results to chart")
|
||||
return
|
||||
|
||||
# count occurrences at each rank position
|
||||
rank_counts: dict[int, int] = defaultdict(int)
|
||||
for result in found_results:
|
||||
if result.rank is not None:
|
||||
rank_counts[result.rank] += 1
|
||||
|
||||
# create the data for plotting
|
||||
if found_results:
|
||||
max_rank = max(rank_counts.keys())
|
||||
positions = list(range(1, max_rank + 1))
|
||||
counts = [rank_counts.get(pos, 0) for pos in positions]
|
||||
else:
|
||||
positions = []
|
||||
counts = []
|
||||
|
||||
# add the "not found" bar on the far right
|
||||
if not_found_count > 0:
|
||||
# add some spacing between found positions and "not found"
|
||||
not_found_position = (max(positions) + 2) if positions else 1
|
||||
positions.append(not_found_position)
|
||||
counts.append(not_found_count)
|
||||
|
||||
# create labels for x-axis
|
||||
x_labels = [str(pos) for pos in positions[:-1]] + [
|
||||
f"not found\n(>{self.config.max_search_results})"
|
||||
]
|
||||
else:
|
||||
x_labels = [str(pos) for pos in positions]
|
||||
|
||||
# create the figure and bar chart
|
||||
plt.figure(figsize=(14, 6))
|
||||
|
||||
# use different colors for found vs not found
|
||||
colors = (
|
||||
["#3498db"] * (len(positions) - 1) + ["#e74c3c"]
|
||||
if not_found_count > 0
|
||||
else ["#3498db"] * len(positions)
|
||||
)
|
||||
bars = plt.bar(
|
||||
positions, counts, color=colors, alpha=0.7, edgecolor="black", linewidth=0.5
|
||||
)
|
||||
|
||||
# customize the chart
|
||||
plt.xlabel("Position in Search Results", fontsize=12)
|
||||
plt.ylabel("Number of Ground Truth Documents", fontsize=12)
|
||||
plt.title(
|
||||
"Ground Truth Document Positions in Search Results",
|
||||
fontsize=14,
|
||||
fontweight="bold",
|
||||
)
|
||||
plt.grid(axis="y", alpha=0.3)
|
||||
|
||||
# add value labels on top of each bar
|
||||
for bar, count in zip(bars, counts):
|
||||
if count > 0:
|
||||
plt.text(
|
||||
bar.get_x() + bar.get_width() / 2,
|
||||
bar.get_height() + 0.1,
|
||||
str(count),
|
||||
ha="center",
|
||||
va="bottom",
|
||||
fontweight="bold",
|
||||
)
|
||||
|
||||
# set x-axis labels
|
||||
plt.xticks(positions, x_labels, rotation=45 if not_found_count > 0 else 0)
|
||||
|
||||
# add legend if we have both found and not found
|
||||
if not_found_count > 0 and found_results:
|
||||
from matplotlib.patches import Patch
|
||||
|
||||
legend_elements = [
|
||||
Patch(facecolor="#3498db", alpha=0.7, label="Found in Results"),
|
||||
Patch(facecolor="#e74c3c", alpha=0.7, label="Not Found"),
|
||||
]
|
||||
plt.legend(handles=legend_elements, loc="upper right")
|
||||
|
||||
# make layout tight and save
|
||||
plt.tight_layout()
|
||||
|
||||
chart_file = export_path / "search_position_chart.png"
|
||||
plt.savefig(chart_file, dpi=300, bbox_inches="tight")
|
||||
logger.info("Search position chart saved to: %s", chart_file)
|
||||
|
||||
plt.show()
|
||||
|
||||
def _load_dataset(self, dataset_path: Path) -> list[TestQuery]:
|
||||
"""Load the test dataset from a JSON file and validate the ground truth documents."""
|
||||
with dataset_path.open("r") as f:
|
||||
dataset_raw: list[dict] = json.load(f)
|
||||
|
||||
dataset: list[TestQuery] = []
|
||||
for datum in dataset_raw:
|
||||
# validate the raw datum
|
||||
try:
|
||||
test_query = TestQuery(**datum)
|
||||
except ValidationError as e:
|
||||
logger.error("Incorrectly formatted query: %s", e)
|
||||
continue
|
||||
|
||||
# in case the dataset was copied from the previous run export
|
||||
if test_query.ground_truth_docids:
|
||||
dataset.append(test_query)
|
||||
continue
|
||||
|
||||
# validate and get the ground truth documents
|
||||
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
|
||||
for ground_truth in test_query.ground_truth:
|
||||
doc = find_document(ground_truth, db_session)
|
||||
if doc:
|
||||
test_query.ground_truth_docids.append(doc.id)
|
||||
|
||||
if len(test_query.ground_truth_docids) == 0:
|
||||
logger.warning(
|
||||
"No ground truth documents found for query: %s, skipping...",
|
||||
test_query.question,
|
||||
)
|
||||
continue
|
||||
|
||||
dataset.append(test_query)
|
||||
|
||||
return dataset
|
||||
|
||||
@retry(tries=3, delay=1, backoff=2)
|
||||
def _perform_oneshot_qa(self, query: str) -> OneshotQAResult:
|
||||
"""Perform a OneShot QA query against the Onyx API and time it."""
|
||||
# create the thread message
|
||||
messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]
|
||||
|
||||
# create filters (empty to search all sources)
|
||||
filters = IndexFilters(
|
||||
source_type=None,
|
||||
document_set=None,
|
||||
time_cutoff=None,
|
||||
tags=None,
|
||||
access_control_list=None,
|
||||
)
|
||||
|
||||
# create the OneShot QA request
|
||||
qa_request = OneShotQARequest(
|
||||
messages=messages,
|
||||
prompt_id=0, # default prompt
|
||||
persona_id=0, # default persona
|
||||
retrieval_options=RetrievalDetails(
|
||||
run_search=OptionalSearchSetting.ALWAYS,
|
||||
real_time=True,
|
||||
filters=filters,
|
||||
enable_auto_detect_filters=False,
|
||||
limit=self.config.max_search_results,
|
||||
),
|
||||
rerank_settings=self._rerank_settings,
|
||||
return_contexts=True,
|
||||
skip_gen_ai_answer_generation=self.config.search_only,
|
||||
)
|
||||
|
||||
# send the request
|
||||
response = None
|
||||
try:
|
||||
request_data = qa_request.model_dump()
|
||||
|
||||
start_time = time.monotonic()
|
||||
response = requests.post(
|
||||
url=f"{self.config.api_url}/query/answer-with-citation",
|
||||
json=request_data,
|
||||
headers=GENERAL_HEADERS,
|
||||
timeout=self.config.request_timeout,
|
||||
)
|
||||
time_taken = time.monotonic() - start_time
|
||||
response.raise_for_status()
|
||||
result = OneShotQAResponse.model_validate(response.json())
|
||||
|
||||
# extract documents from the QA response
|
||||
if result.docs:
|
||||
top_documents = result.docs.top_documents
|
||||
return OneshotQAResult(
|
||||
time_taken=time_taken,
|
||||
top_documents=top_documents,
|
||||
answer=result.answer,
|
||||
)
|
||||
except RequestException as e:
|
||||
raise RuntimeError(
|
||||
f"OneShot QA failed for query '{query}': {e}."
|
||||
f" Response: {response.json()}"
|
||||
if response
|
||||
else ""
|
||||
)
|
||||
raise RuntimeError(f"OneShot QA returned no documents for query {query}")
|
||||
|
||||
def _run_and_analyze_one(self, test_case: TestQuery) -> AnalysisSummary:
|
||||
result = self._perform_oneshot_qa(test_case.question)
|
||||
|
||||
# compute rank
|
||||
rank = self.config.max_search_results
|
||||
found = False
|
||||
ground_truths = set(test_case.ground_truth_docids)
|
||||
for i, doc in enumerate(result.top_documents, 1):
|
||||
if doc.document_id in ground_truths:
|
||||
rank = i
|
||||
found = True
|
||||
break
|
||||
|
||||
# get the search contents
|
||||
retrieved = search_docs_to_doc_contexts(result.top_documents)
|
||||
|
||||
# do answer evaluation
|
||||
response_relevancy: float | None = None
|
||||
response_groundedness: float | None = None
|
||||
faithfulness: float | None = None
|
||||
contexts = [c.content for c in retrieved[: self.config.max_answer_context]]
|
||||
if not self.config.search_only:
|
||||
if result.answer is None:
|
||||
logger.error(
|
||||
"No answer found for query: %s, skipping answer evaluation",
|
||||
test_case.question,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
ragas_result = ragas_evaluate(
|
||||
question=test_case.question,
|
||||
answer=result.answer,
|
||||
contexts=contexts,
|
||||
).scores[0]
|
||||
response_relevancy = ragas_result["answer_relevancy"]
|
||||
response_groundedness = ragas_result["nv_response_groundedness"]
|
||||
faithfulness = ragas_result["faithfulness"]
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
"Error evaluating answer for query %s: %s",
|
||||
test_case.question,
|
||||
e,
|
||||
)
|
||||
|
||||
return AnalysisSummary(
|
||||
question=test_case.question,
|
||||
categories=test_case.categories,
|
||||
found=found,
|
||||
rank=rank,
|
||||
total_results=len(result.top_documents),
|
||||
ground_truth_count=len(test_case.ground_truth_docids),
|
||||
answer=result.answer,
|
||||
response_relevancy=response_relevancy,
|
||||
response_groundedness=response_groundedness,
|
||||
faithfulness=faithfulness,
|
||||
retrieved=retrieved,
|
||||
time_taken=result.time_taken,
|
||||
)
|
||||
|
||||
def _run_and_analyze_one_wrapper(
|
||||
self, test_case_with_index: tuple[int, TestQuery]
|
||||
) -> tuple[int, AnalysisSummary]:
|
||||
index, test_case = test_case_with_index
|
||||
return index, self._run_and_analyze_one(test_case)
|
||||
|
||||
def _compute_combined_metrics(
|
||||
self, results: list[AnalysisSummary]
|
||||
) -> CombinedMetrics:
|
||||
"""Aggregate analysis summaries into CombinedMetrics."""
|
||||
|
||||
total_queries = len(results)
|
||||
found_ranks = [r.rank for r in results if r.found and r.rank is not None]
|
||||
found_count = len(found_ranks)
|
||||
best_rank = 0
|
||||
worst_rank = 0
|
||||
average_rank = 0.0
|
||||
response_relevancy = 0.0
|
||||
response_groundedness = 0.0
|
||||
faithfulness = 0.0
|
||||
|
||||
if found_ranks:
|
||||
best_rank = min(found_ranks)
|
||||
worst_rank = max(found_ranks)
|
||||
average_rank = sum(found_ranks) / found_count
|
||||
|
||||
if not self.config.search_only:
|
||||
scores = [
|
||||
r.response_relevancy
|
||||
for r in results
|
||||
if r.response_relevancy is not None
|
||||
]
|
||||
response_relevancy = sum(scores) / len(scores)
|
||||
scores = [
|
||||
r.response_groundedness
|
||||
for r in results
|
||||
if r.response_groundedness is not None
|
||||
]
|
||||
response_groundedness = sum(scores) / len(scores)
|
||||
scores = [r.faithfulness for r in results if r.faithfulness is not None]
|
||||
faithfulness = sum(scores) / len(scores)
|
||||
|
||||
top_k_accuracy: dict[int, float] = {}
|
||||
for k in TOP_K_LIST:
|
||||
hits = sum(1 for rank in found_ranks if rank <= k)
|
||||
top_k_accuracy[k] = (hits / total_queries * 100) if total_queries else 0.0
|
||||
|
||||
times = [r.time_taken for r in results if r.time_taken is not None]
|
||||
avg_time_taken = sum(times) / len(times) if times else 0.0
|
||||
|
||||
return CombinedMetrics(
|
||||
total_queries=total_queries,
|
||||
found_count=found_count,
|
||||
best_rank=best_rank,
|
||||
worst_rank=worst_rank,
|
||||
average_rank=average_rank,
|
||||
top_k_accuracy=top_k_accuracy,
|
||||
average_response_relevancy=response_relevancy,
|
||||
average_response_groundedness=response_groundedness,
|
||||
average_faithfulness=faithfulness,
|
||||
average_time_taken=avg_time_taken,
|
||||
)
|
||||
|
||||
def _build_metrics(self) -> None:
|
||||
self.metrics = {
|
||||
cat: self._compute_combined_metrics(res_list)
|
||||
for cat, res_list in self.stats.items()
|
||||
}
|
||||
|
||||
def _get_rerank_settings(self) -> RerankingDetails | None:
|
||||
"""Fetch the tenant's reranking settings from the database."""
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
if search_settings:
|
||||
rerank_settings = RerankingDetails.from_db_model(search_settings)
|
||||
if not self.config.rerank_all:
|
||||
return rerank_settings
|
||||
|
||||
# override the num_rerank to the eval limit
|
||||
rerank_settings = rerank_settings.model_copy(
|
||||
update={"num_rerank": self.config.max_search_results}
|
||||
)
|
||||
return rerank_settings
|
||||
except Exception as e:
|
||||
logger.warning("Could not load rerank settings from DB: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def run_search_eval(
|
||||
dataset_path: Path,
|
||||
config: EvalConfig,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
if not config.search_only and not os.environ.get("OPENAI_API_KEY"):
|
||||
raise RuntimeError("OPENAI_API_KEY is required for answer evaluation")
|
||||
|
||||
# check onyx is running
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{config.api_url}/health", timeout=config.request_timeout
|
||||
)
|
||||
response.raise_for_status()
|
||||
except RequestException as e:
|
||||
raise RuntimeError(f"Could not connect to Onyx API: {e}")
|
||||
logger.info("Onyx API is running")
|
||||
|
||||
# create the export folder
|
||||
export_folder = current_dir / datetime.now().strftime("eval-%Y-%m-%d-%H-%M-%S")
|
||||
export_path = Path(export_folder)
|
||||
export_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info("Created export folder: %s", export_path)
|
||||
|
||||
# run the search eval
|
||||
analyzer = SearchAnswerAnalyzer(config=config, tenant_id=tenant_id)
|
||||
analyzer.run_analysis(dataset_path, export_path)
|
||||
analyzer.generate_summary()
|
||||
analyzer.generate_detailed_report(export_path)
|
||||
analyzer.generate_chart(export_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if MULTI_TENANT:
|
||||
raise ValueError("Multi-tenant is not supported currently")
|
||||
import argparse
|
||||
|
||||
current_dir = Path(__file__).parent
|
||||
parser = argparse.ArgumentParser(description="Run search quality evaluation.")
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--dataset",
|
||||
type=Path,
|
||||
default=current_dir / "test_queries.json",
|
||||
help="Path to the test-set JSON file (default: %(default)s).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-n",
|
||||
"--num_search",
|
||||
type=int,
|
||||
default=50,
|
||||
help="Maximum number of search results to check per query (default: %(default)s).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-a",
|
||||
"--num_answer",
|
||||
type=int,
|
||||
default=25,
|
||||
help="Maximum number of search results to use for answer evaluation (default: %(default)s).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-w",
|
||||
"--workers",
|
||||
type=int,
|
||||
default=10,
|
||||
help="Number of parallel search requests (default: %(default)s).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-q",
|
||||
"--timeout",
|
||||
type=int,
|
||||
default=120,
|
||||
help="Request timeout in seconds (default: %(default)s).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-e",
|
||||
"--api_endpoint",
|
||||
type=str,
|
||||
default="http://127.0.0.1:8080",
|
||||
help="Base URL of the Onyx API server (default: %(default)s).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-s",
|
||||
"--search_only",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Only perform search and not answer evaluation (default: %(default)s).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-r",
|
||||
"--rerank_all",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Always rerank all search results (default: %(default)s).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tenant_id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Tenant ID to use for the evaluation (default: %(default)s).",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
SqlEngine.init_engine(
|
||||
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
|
||||
@@ -153,9 +694,21 @@ if __name__ == "__main__":
|
||||
)
|
||||
|
||||
try:
|
||||
run_search_eval()
|
||||
run_search_eval(
|
||||
args.dataset,
|
||||
EvalConfig(
|
||||
max_search_results=args.num_search,
|
||||
max_answer_context=args.num_answer,
|
||||
num_workers=args.workers,
|
||||
request_timeout=args.timeout,
|
||||
api_url=args.api_endpoint,
|
||||
search_only=args.search_only,
|
||||
rerank_all=args.rerank_all,
|
||||
),
|
||||
args.tenant_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error running search evaluation: {e}")
|
||||
raise e
|
||||
logger.error("Unexpected error during search evaluation: %s", e)
|
||||
raise
|
||||
finally:
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
# Search Parameters
|
||||
HYBRID_ALPHA: 0.5
|
||||
HYBRID_ALPHA_KEYWORD: 0.4
|
||||
DOC_TIME_DECAY: 0.5
|
||||
NUM_RETURNED_HITS: 50 # Setting to a higher value will improve evaluation quality but increase reranking time
|
||||
RANK_PROFILE: 'semantic'
|
||||
OFFSET: 0
|
||||
TITLE_CONTENT_RATIO: 0.1
|
||||
USER_EMAIL: null # User email to use for testing, modifies access control list, null means only public files
|
||||
|
||||
# Evaluation parameters
|
||||
SKIP_RERANK: false # Whether to skip reranking, reranking must be enabled to evaluate the search results
|
||||
EVAL_TOPK: 5 # Number of top results from the searcher and reranker to evaluate, lower means stricter evaluation
|
||||
|
||||
# Export file, will export a csv file with the results and a json file with the parameters
|
||||
EXPORT_FOLDER: "eval-%Y-%m-%d-%H-%M-%S"
|
||||
@@ -3,20 +3,18 @@
|
||||
"question": "What is Onyx?",
|
||||
"ground_truth": [
|
||||
{
|
||||
"doc_source": "Web",
|
||||
"doc_source": "web",
|
||||
"doc_link": "https://docs.onyx.app/more/use_cases/overview"
|
||||
},
|
||||
{
|
||||
"doc_source": "Web",
|
||||
"doc_source": "web",
|
||||
"doc_link": "https://docs.onyx.app/more/use_cases/ai_platform"
|
||||
}
|
||||
],
|
||||
"categories": [
|
||||
"keyword",
|
||||
"broad"
|
||||
"broad",
|
||||
"easy"
|
||||
]
|
||||
},
|
||||
{
|
||||
"question": "What is the meaning of life?"
|
||||
}
|
||||
]
|
||||
@@ -1,75 +0,0 @@
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
|
||||
from onyx.configs.chat_configs import DOC_TIME_DECAY
|
||||
from onyx.configs.chat_configs import HYBRID_ALPHA
|
||||
from onyx.configs.chat_configs import HYBRID_ALPHA_KEYWORD
|
||||
from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class SearchEvalConfig(BaseModel):
|
||||
hybrid_alpha: float
|
||||
hybrid_alpha_keyword: float
|
||||
doc_time_decay: float
|
||||
num_returned_hits: int
|
||||
rank_profile: QueryExpansionType
|
||||
offset: int
|
||||
title_content_ratio: float
|
||||
user_email: str | None
|
||||
skip_rerank: bool
|
||||
eval_topk: int
|
||||
export_folder: str
|
||||
|
||||
|
||||
def load_config() -> SearchEvalConfig:
|
||||
"""Loads the search evaluation configs from the config file."""
|
||||
# open the config file
|
||||
current_dir = Path(__file__).parent
|
||||
config_path = current_dir / "search_eval_config.yaml"
|
||||
if not config_path.exists():
|
||||
raise FileNotFoundError(f"Search eval config file not found at {config_path}")
|
||||
with config_path.open("r") as file:
|
||||
config_raw = yaml.safe_load(file)
|
||||
|
||||
# create the export folder
|
||||
export_folder = config_raw.get("EXPORT_FOLDER", "eval-%Y-%m-%d-%H-%M-%S")
|
||||
export_folder = datetime.now().strftime(export_folder)
|
||||
export_path = Path(export_folder)
|
||||
export_path.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(f"Created export folder: {export_path}")
|
||||
|
||||
# create the config
|
||||
config = SearchEvalConfig(
|
||||
hybrid_alpha=config_raw.get("HYBRID_ALPHA", HYBRID_ALPHA),
|
||||
hybrid_alpha_keyword=config_raw.get(
|
||||
"HYBRID_ALPHA_KEYWORD", HYBRID_ALPHA_KEYWORD
|
||||
),
|
||||
doc_time_decay=config_raw.get("DOC_TIME_DECAY", DOC_TIME_DECAY),
|
||||
num_returned_hits=config_raw.get("NUM_RETURNED_HITS", NUM_RETURNED_HITS),
|
||||
rank_profile=config_raw.get("RANK_PROFILE", QueryExpansionType.SEMANTIC),
|
||||
offset=config_raw.get("OFFSET", 0),
|
||||
title_content_ratio=config_raw.get("TITLE_CONTENT_RATIO", TITLE_CONTENT_RATIO),
|
||||
user_email=config_raw.get("USER_EMAIL"),
|
||||
skip_rerank=config_raw.get("SKIP_RERANK", False),
|
||||
eval_topk=config_raw.get("EVAL_TOPK", 5),
|
||||
export_folder=export_folder,
|
||||
)
|
||||
logger.info(f"Using search parameters: {config}")
|
||||
|
||||
# export the config
|
||||
config_file = export_path / "search_eval_config.yaml"
|
||||
with config_file.open("w") as file:
|
||||
config_dict = config.model_dump(mode="python")
|
||||
config_dict["rank_profile"] = config.rank_profile.value
|
||||
yaml.dump(config_dict, file, sort_keys=False)
|
||||
logger.info(f"Exported config to {config_file}")
|
||||
|
||||
return config
|
||||
@@ -1,166 +0,0 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
from langgraph.types import StreamWriter
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ValidationError
|
||||
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.persona import get_persona_by_id
|
||||
from onyx.llm.factory import get_llms_for_persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.utils import explicit_tool_calling_supported
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class GroundTruth(BaseModel):
|
||||
doc_source: str
|
||||
doc_link: str
|
||||
|
||||
|
||||
class TestQuery(BaseModel):
|
||||
question: str
|
||||
question_search: Optional[str] = None
|
||||
ground_truth: list[GroundTruth] = []
|
||||
categories: list[str] = []
|
||||
|
||||
|
||||
def load_test_queries() -> list[TestQuery]:
|
||||
"""
|
||||
Loads the test queries from the test_queries.json file.
|
||||
If `question_search` is missing, it will use the tool-calling LLM to generate it.
|
||||
"""
|
||||
# open test queries file
|
||||
current_dir = Path(__file__).parent
|
||||
test_queries_path = current_dir / "test_queries.json"
|
||||
logger.info(f"Loading test queries from {test_queries_path}")
|
||||
if not test_queries_path.exists():
|
||||
raise FileNotFoundError(f"Test queries file not found at {test_queries_path}")
|
||||
with test_queries_path.open("r") as f:
|
||||
test_queries_raw: list[dict] = json.load(f)
|
||||
|
||||
# setup llm for question_search generation
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
search_tool = SearchToolOverride()
|
||||
|
||||
tool_call_supported = explicit_tool_calling_supported(
|
||||
llm.config.model_provider, llm.config.model_name
|
||||
)
|
||||
|
||||
# validate keys and generate question_search if missing
|
||||
test_queries: list[TestQuery] = []
|
||||
for query_raw in test_queries_raw:
|
||||
try:
|
||||
test_query = TestQuery(**query_raw)
|
||||
except ValidationError as e:
|
||||
logger.error(f"Incorrectly formatted query: {e}")
|
||||
continue
|
||||
|
||||
if test_query.question_search is None:
|
||||
test_query.question_search = _modify_one_query(
|
||||
query=test_query.question,
|
||||
llm=llm,
|
||||
prompt_config=prompt_config,
|
||||
tool=search_tool,
|
||||
tool_call_supported=tool_call_supported,
|
||||
)
|
||||
test_queries.append(test_query)
|
||||
|
||||
return test_queries
|
||||
|
||||
|
||||
def export_test_queries(test_queries: list[TestQuery], export_path: Path) -> None:
|
||||
"""Exports the test queries to a JSON file."""
|
||||
logger.info(f"Exporting test queries to {export_path}")
|
||||
with export_path.open("w") as f:
|
||||
json.dump(
|
||||
[query.model_dump() for query in test_queries],
|
||||
f,
|
||||
indent=4,
|
||||
)
|
||||
|
||||
|
||||
class SearchToolOverride(SearchTool):
|
||||
def __init__(self) -> None:
|
||||
# do nothing, only class variables are required for the functions we call
|
||||
pass
|
||||
|
||||
|
||||
warned = False
|
||||
|
||||
|
||||
def _modify_one_query(
|
||||
query: str,
|
||||
llm: LLM,
|
||||
prompt_config: PromptConfig,
|
||||
tool: SearchTool,
|
||||
tool_call_supported: bool,
|
||||
writer: StreamWriter = lambda _: None,
|
||||
) -> str:
|
||||
global warned
|
||||
if not warned:
|
||||
logger.warning(
|
||||
"Generating question_search. If you do not save the question_search, "
|
||||
"it will be generated again on the next run, potentially altering the search results."
|
||||
)
|
||||
warned = True
|
||||
|
||||
prompt_builder = AnswerPromptBuilder(
|
||||
user_message=default_build_user_message(
|
||||
user_query=query,
|
||||
prompt_config=prompt_config,
|
||||
files=[],
|
||||
single_message_history=None,
|
||||
),
|
||||
system_message=default_build_system_message(prompt_config, llm.config),
|
||||
message_history=[],
|
||||
llm_config=llm.config,
|
||||
raw_user_query=query,
|
||||
raw_user_uploaded_files=[],
|
||||
single_message_history=None,
|
||||
)
|
||||
|
||||
if tool_call_supported:
|
||||
prompt = prompt_builder.build()
|
||||
tool_definition = tool.tool_definition()
|
||||
stream = llm.stream(
|
||||
prompt=prompt,
|
||||
tools=[tool_definition],
|
||||
tool_choice="required",
|
||||
structured_response_format=None,
|
||||
)
|
||||
tool_message = process_llm_stream(
|
||||
messages=stream,
|
||||
should_stream_answer=False,
|
||||
writer=writer,
|
||||
)
|
||||
return (
|
||||
tool_message.tool_calls[0]["args"]["query"]
|
||||
if tool_message.tool_calls
|
||||
else query
|
||||
)
|
||||
|
||||
history = prompt_builder.get_message_history()
|
||||
return cast(
|
||||
dict[str, str],
|
||||
tool.get_args_for_non_tool_calling_llm(
|
||||
query=query,
|
||||
history=history,
|
||||
llm=llm,
|
||||
force_run=True,
|
||||
),
|
||||
)["query"]
|
||||
@@ -1,94 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.models import Document
|
||||
from onyx.utils.logger import setup_logger
|
||||
from tests.regression.search_quality.util_retrieve import group_by_documents
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
class Metrics(BaseModel):
|
||||
# computed if ground truth is provided
|
||||
ground_truth_ratio_topk: Optional[float] = None
|
||||
ground_truth_avg_rank_delta: Optional[float] = None
|
||||
|
||||
# computed if reranked results are provided
|
||||
soft_truth_ratio_topk: Optional[float] = None
|
||||
soft_truth_avg_rank_delta: Optional[float] = None
|
||||
|
||||
|
||||
metric_names = list(Metrics.model_fields.keys())
|
||||
|
||||
|
||||
def get_corresponding_document(
|
||||
doc_link: str, db_session: Session
|
||||
) -> Optional[Document]:
|
||||
"""Get the corresponding document from the database."""
|
||||
doc_filter = db_session.query(Document).filter(Document.link == doc_link)
|
||||
count = doc_filter.count()
|
||||
if count == 0:
|
||||
logger.warning(f"Could not find document with link {doc_link}, ignoring")
|
||||
return None
|
||||
if count > 1:
|
||||
logger.warning(f"Found multiple documents with link {doc_link}, using first")
|
||||
return doc_filter.first()
|
||||
|
||||
|
||||
def evaluate_one_query(
|
||||
search_chunks: list[InferenceChunk],
|
||||
rerank_chunks: list[InferenceChunk],
|
||||
true_documents: list[Document],
|
||||
topk: int,
|
||||
) -> Metrics:
|
||||
"""Computes metrics for the search results, relative to the ground truth and reranked results."""
|
||||
metrics_dict: dict[str, float] = {}
|
||||
|
||||
search_documents = group_by_documents(search_chunks)
|
||||
search_ranks = {docid: rank for rank, docid in enumerate(search_documents)}
|
||||
search_ranks_topk = {
|
||||
docid: rank for rank, docid in enumerate(search_documents[:topk])
|
||||
}
|
||||
true_ranks = {doc.id: rank for rank, doc in enumerate(true_documents)}
|
||||
|
||||
if true_documents:
|
||||
metrics_dict["ground_truth_ratio_topk"] = _compute_ratio(
|
||||
search_ranks_topk, true_ranks
|
||||
)
|
||||
metrics_dict["ground_truth_avg_rank_delta"] = _compute_avg_rank_delta(
|
||||
search_ranks, true_ranks
|
||||
)
|
||||
|
||||
if rerank_chunks:
|
||||
# build soft truth out of ground truth + reranked results, up to topk
|
||||
soft_ranks = true_ranks
|
||||
for docid in group_by_documents(rerank_chunks):
|
||||
if len(soft_ranks) >= topk:
|
||||
break
|
||||
if docid not in soft_ranks:
|
||||
soft_ranks[docid] = len(soft_ranks)
|
||||
|
||||
metrics_dict["soft_truth_ratio_topk"] = _compute_ratio(
|
||||
search_ranks_topk, soft_ranks
|
||||
)
|
||||
metrics_dict["soft_truth_avg_rank_delta"] = _compute_avg_rank_delta(
|
||||
search_ranks, soft_ranks
|
||||
)
|
||||
|
||||
return Metrics(**metrics_dict)
|
||||
|
||||
|
||||
def _compute_ratio(search_ranks: dict[str, int], true_ranks: dict[str, int]) -> float:
|
||||
return len(set(search_ranks) & set(true_ranks)) / len(true_ranks)
|
||||
|
||||
|
||||
def _compute_avg_rank_delta(
|
||||
search_ranks: dict[str, int], true_ranks: dict[str, int]
|
||||
) -> float:
|
||||
out = len(search_ranks)
|
||||
return sum(
|
||||
abs(search_ranks.get(docid, out) - rank) for docid, rank in true_ranks.items()
|
||||
) / len(true_ranks)
|
||||
@@ -1,88 +0,0 @@
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.postprocessing.postprocessing import semantic_reranking
|
||||
from onyx.context.search.preprocessing.preprocessing import query_analysis
|
||||
from onyx.context.search.retrieval.search_runner import get_query_embedding
|
||||
from onyx.context.search.utils import remove_stop_words_and_punctuation
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.utils.logger import setup_logger
|
||||
from tests.regression.search_quality.util_config import SearchEvalConfig
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def search_one_query(
|
||||
question_search: str,
|
||||
multilingual_expansion: list[str],
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
config: SearchEvalConfig,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Uses the search pipeline to retrieve relevant chunks for the given query."""
|
||||
# the retrieval preprocessing is fairly stripped down so the query doesn't unexpectedly change
|
||||
query_embedding = get_query_embedding(question_search, db_session)
|
||||
|
||||
all_query_terms = question_search.split()
|
||||
processed_keywords = (
|
||||
remove_stop_words_and_punctuation(all_query_terms)
|
||||
if not multilingual_expansion
|
||||
else all_query_terms
|
||||
)
|
||||
|
||||
is_keyword = query_analysis(question_search)[0]
|
||||
hybrid_alpha = config.hybrid_alpha_keyword if is_keyword else config.hybrid_alpha
|
||||
|
||||
access_control_list = ["PUBLIC"]
|
||||
if config.user_email:
|
||||
access_control_list.append(f"user_email:{config.user_email}")
|
||||
filters = IndexFilters(
|
||||
tags=[],
|
||||
user_file_ids=[],
|
||||
user_folder_ids=[],
|
||||
access_control_list=access_control_list,
|
||||
tenant_id=None,
|
||||
)
|
||||
|
||||
results = document_index.hybrid_retrieval(
|
||||
query=question_search,
|
||||
query_embedding=query_embedding,
|
||||
final_keywords=processed_keywords,
|
||||
filters=filters,
|
||||
hybrid_alpha=hybrid_alpha,
|
||||
time_decay_multiplier=config.doc_time_decay,
|
||||
num_to_retrieve=config.num_returned_hits,
|
||||
ranking_profile_type=config.rank_profile,
|
||||
offset=config.offset,
|
||||
title_content_ratio=config.title_content_ratio,
|
||||
)
|
||||
|
||||
return [result.to_inference_chunk() for result in results]
|
||||
|
||||
|
||||
def rerank_one_query(
|
||||
question: str,
|
||||
retrieved_chunks: list[InferenceChunk],
|
||||
rerank_settings: RerankingDetails,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Uses the reranker to rerank the retrieved chunks for the given query."""
|
||||
rerank_settings.num_rerank = len(retrieved_chunks)
|
||||
return semantic_reranking(
|
||||
query_str=question,
|
||||
rerank_settings=rerank_settings,
|
||||
chunks=retrieved_chunks,
|
||||
rerank_metrics_callback=None,
|
||||
)[0]
|
||||
|
||||
|
||||
def group_by_documents(chunks: list[InferenceChunk]) -> list[str]:
|
||||
"""Groups a sorted list of chunks into a sorted list of document ids."""
|
||||
seen_docids: set[str] = set()
|
||||
retrieved_docids: list[str] = []
|
||||
for chunk in chunks:
|
||||
if chunk.document_id not in seen_docids:
|
||||
seen_docids.add(chunk.document_id)
|
||||
retrieved_docids.append(chunk.document_id)
|
||||
return retrieved_docids
|
||||
78
backend/tests/regression/search_quality/utils.py
Normal file
78
backend/tests/regression/search_quality/utils.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from ragas import evaluate # type: ignore
|
||||
from ragas import EvaluationDataset # type: ignore
|
||||
from ragas import SingleTurnSample # type: ignore
|
||||
from ragas.dataset_schema import EvaluationResult # type: ignore
|
||||
from ragas.metrics import Faithfulness # type: ignore
|
||||
from ragas.metrics import ResponseGroundedness # type: ignore
|
||||
from ragas.metrics import ResponseRelevancy # type: ignore
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.db.models import Document
|
||||
from onyx.prompts.prompt_utils import build_doc_context_str
|
||||
from onyx.utils.logger import setup_logger
|
||||
from tests.regression.search_quality.models import GroundTruth
|
||||
from tests.regression.search_quality.models import RetrievedDocument
|
||||
|
||||
logger = setup_logger(__name__)
|
||||
|
||||
|
||||
def find_document(ground_truth: GroundTruth, db_session: Session) -> Document | None:
|
||||
"""Find a document by its link."""
|
||||
# necessary preprocessing of links
|
||||
doc_link = ground_truth.doc_link
|
||||
if ground_truth.doc_source == DocumentSource.GOOGLE_DRIVE:
|
||||
if "/edit" in doc_link:
|
||||
doc_link = doc_link.split("/edit", 1)[0]
|
||||
elif "/view" in doc_link:
|
||||
doc_link = doc_link.split("/view", 1)[0]
|
||||
elif ground_truth.doc_source == DocumentSource.FIREFLIES:
|
||||
doc_link = doc_link.split("?", 1)[0]
|
||||
|
||||
docs = db_session.query(Document).filter(Document.link.ilike(f"{doc_link}%")).all()
|
||||
if len(docs) == 0:
|
||||
logger.warning("Could not find ground truth document: %s", doc_link)
|
||||
return None
|
||||
elif len(docs) > 1:
|
||||
logger.warning(
|
||||
"Found multiple ground truth documents: %s, using the first one: %s",
|
||||
doc_link,
|
||||
docs[0].id,
|
||||
)
|
||||
return docs[0]
|
||||
|
||||
|
||||
def search_docs_to_doc_contexts(docs: list[SavedSearchDoc]) -> list[RetrievedDocument]:
|
||||
return [
|
||||
RetrievedDocument(
|
||||
document_id=doc.document_id,
|
||||
content=build_doc_context_str(
|
||||
semantic_identifier=doc.semantic_identifier,
|
||||
source_type=doc.source_type,
|
||||
content=doc.blurb, # getting the full content is painful
|
||||
metadata_dict=doc.metadata,
|
||||
updated_at=doc.updated_at,
|
||||
ind=ind,
|
||||
include_metadata=True,
|
||||
),
|
||||
)
|
||||
for ind, doc in enumerate(docs)
|
||||
]
|
||||
|
||||
|
||||
def ragas_evaluate(question: str, answer: str, contexts: list[str]) -> EvaluationResult:
|
||||
sample = SingleTurnSample(
|
||||
user_input=question,
|
||||
retrieved_contexts=contexts,
|
||||
response=answer,
|
||||
)
|
||||
dataset = EvaluationDataset([sample])
|
||||
return evaluate(
|
||||
dataset,
|
||||
metrics=[
|
||||
ResponseRelevancy(),
|
||||
ResponseGroundedness(),
|
||||
Faithfulness(),
|
||||
],
|
||||
)
|
||||
Reference in New Issue
Block a user