Compare commits

...

26 Commits

Author SHA1 Message Date
Raunak Bhagat
1f91fd51ee Saving changes 2025-07-02 09:27:00 -07:00
Raunak Bhagat
3accbd2c11 Merge branch 'main' into ronnie/hackathon 2025-07-01 18:15:49 -07:00
Raunak Bhagat
4a8348a169 Merge remote-tracking branch 'origin/ronnie/hackathon' into ronnie/hackathon 2025-07-01 17:49:50 -07:00
Raunak Bhagat
bdadef00cd Revert search logic 2025-07-01 16:32:56 -07:00
Raunak Bhagat
318b5c0c6a Minor edits 2025-07-01 14:23:39 -07:00
Raunak Bhagat
6cfdedb9d9 Merge remote-tracking branch 'origin/search-quality-tests-revamp' into ronnie/hackathon 2025-07-01 13:45:06 -07:00
Raunak Bhagat
7782a735c2 Saving minor changes 2025-07-01 13:37:38 -07:00
Rei Meguro
060f097737 mypy ragas 2025-07-01 12:54:34 -07:00
Rei Meguro
a294f4464c more error messages 2025-07-01 12:52:17 -07:00
Rei Meguro
7f6eb68295 feat: answer eval 2025-07-01 12:43:26 -07:00
Rei Meguro
4165150b20 document context + skip genai fix 2025-07-01 10:54:40 -07:00
Rei Meguro
4f9968727c refactor: config refactor 2025-07-01 10:03:34 -07:00
Rei Meguro
833bac077b nits 2025-07-01 01:33:20 -07:00
Rei Meguro
f06ae0a340 search quality tests improvement
Co-authored-by: wenxi-onyx <wenxi@onyx.app>
2025-07-01 01:11:36 -07:00
Raunak Bhagat
b082e845fd Fix type errors 2025-06-30 16:00:30 -07:00
Raunak Bhagat
81148e10e2 Fix example for running basic-graph (+ minor formatting nits) 2025-06-30 16:00:30 -07:00
Raunak Bhagat
8b5ff48271 Add hackathon folder to gitignore 2025-06-30 16:00:30 -07:00
joachim-danswer
013bed3157 fix 2025-06-30 15:19:42 -07:00
joachim-danswer
289f27c43a updates 2025-06-30 15:06:12 -07:00
Raunak Bhagat
6a78343bf3 Fix example for running basic-graph (+ minor formatting nits) 2025-06-30 14:56:25 -07:00
Raunak Bhagat
655a5f3d56 Add hackathon folder to gitignore 2025-06-30 14:56:02 -07:00
joachim-danswer
736a9bd332 erase history 2025-06-30 09:01:23 -07:00
joachim-danswer
8bcad415bb nit 2025-06-30 08:16:43 -07:00
joachim-danswer
93e6e4a089 mypy nits 2025-06-30 07:49:55 -07:00
joachim-danswer
ed0062dce0 fix 2025-06-30 02:45:03 -07:00
joachim-danswer
6e8bf3120c hackathon v1 changes 2025-06-30 01:39:36 -07:00
23 changed files with 1461 additions and 638 deletions

3
.gitignore vendored
View File

@@ -26,3 +26,6 @@ jira_test_env
/deployment/data/nginx/app.conf
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
# hackathon
hackathon

View File

@@ -82,13 +82,16 @@ if __name__ == "__main__":
from onyx.context.search.models import SearchRequest
from onyx.llm.factory import get_default_llms
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.db.engine.sql_engine import SqlEngine
SqlEngine.init_engine(pool_size=10, max_overflow=0)
graph = basic_graph_builder()
compiled_graph = graph.compile()
input = BasicInput(unused=True)
primary_llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
config, _ = get_test_config(
config = get_test_config(
db_session=db_session,
primary_llm=primary_llm,
fast_llm=fast_llm,

View File

@@ -51,6 +51,7 @@ def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
else:
continue
history_segments.append(f"{role}:\n {msg.content}\n\n")
return "\n".join(history_segments)

View File

@@ -1,7 +1,12 @@
import csv
import json
import os
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import GraphConfig
@@ -11,6 +16,9 @@ from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_agent_search_graph
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import (
run_basic_graph as run_hackathon_graph,
) # You can create your own graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.chat.models import AgentAnswerPiece
@@ -22,9 +30,11 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.models import ToolCallFinalResult
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import RerankingDetails
@@ -44,6 +54,191 @@ logger = setup_logger()
BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
def _calc_score_for_pos(pos: int, max_acceptable_pos: int = 15) -> float:
"""
Calculate the score for a given position.
"""
if pos > max_acceptable_pos:
return 0
elif pos == 1:
return 1
elif pos == 2:
return 0.8
else:
return 4 / (pos + 5)
def _clean_doc_id_link(doc_link: str) -> str:
"""
Clean the google doc link.
"""
if "google.com" in doc_link:
if "/edit" in doc_link:
return "/edit".join(doc_link.split("/edit")[:-1])
elif "/view" in doc_link:
return "/view".join(doc_link.split("/view")[:-1])
else:
return doc_link
if "app.fireflies.ai" in doc_link:
return "?".join(doc_link.split("?")[:-1])
return doc_link
def _get_doc_score(doc_id: str, doc_results: list[str]) -> float:
    """
    Score how highly doc_id ranks within doc_results.

    Links are normalized on both sides before comparison.

    Returns:
        The positional score of the best (earliest) match, or 0.0 when the
        document does not appear in the results.
    """
    # Normalize the target once instead of on every loop iteration.
    clean_target = _clean_doc_id_link(doc_id)
    for pos, candidate in enumerate(doc_results, start=1):
        if _clean_doc_id_link(candidate) == clean_target:
            # Stop at the first (best) position; the previous version kept
            # scanning and therefore scored the *last* occurrence.
            return _calc_score_for_pos(pos)
    return 0.0
def _append_empty_line(csv_path: str = HACKATHON_OUTPUT_CSV_PATH) -> None:
    """
    Append an empty line to the CSV file.
    """
    # Reuses the answer writer with empty query/answer, producing a row whose
    # cells are all blank — serves as a visual separator between runs.
    _append_answer_to_csv("", "", csv_path)
def _append_ground_truth_to_csv(
    query: str,
    ground_truth_docs: list[str],
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the ground-truth documents for a query to the CSV file.

    Each ground-truth document is written as its own row with a sentinel
    position of -1 so it can be distinguished from actual search results.

    Args:
        query: the question the ground-truth documents belong to
        ground_truth_docs: links of the expected documents
        csv_path: Path to the CSV file to append to
    """
    file_exists = os.path.isfile(csv_path)
    # Ensure the parent directory exists (mkdir is a no-op when it does)
    csv_dir = os.path.dirname(csv_path)
    if csv_dir:
        Path(csv_dir).mkdir(parents=True, exist_ok=True)
    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])
        # One row per ground-truth doc, position -1 marks "expected" rows
        for doc_id in ground_truth_docs:
            writer.writerow([query, "-1", _clean_doc_id_link(doc_id), "", ""])
    # Fix copy-pasted log text: this function records ground truth, not scores
    logger.debug("Appended ground truth docs to csv file")
def _append_score_to_csv(
    query: str,
    score: float,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the score to the CSV file.
    """
    is_new_file = not os.path.isfile(csv_path)
    # Make sure the target directory exists before opening for append.
    parent_dir = os.path.dirname(csv_path)
    if parent_dir and not os.path.exists(parent_dir):
        Path(parent_dir).mkdir(parents=True, exist_ok=True)
    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        csv_writer = csv.writer(file)
        if is_new_file:
            # Brand-new file: emit the column header first.
            csv_writer.writerow(["query", "position", "document_id", "answer", "score"])
        # Score rows leave every column blank except query and score.
        csv_writer.writerow([query, "", "", "", score])
    logger.debug("Appended score to csv file")
def _append_search_results_to_csv(
    query: str,
    doc_results: list[str],
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the search results to the CSV file.
    """
    is_new_file = not os.path.isfile(csv_path)
    # Make sure the target directory exists before opening for append.
    parent_dir = os.path.dirname(csv_path)
    if parent_dir and not os.path.exists(parent_dir):
        Path(parent_dir).mkdir(parents=True, exist_ok=True)
    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        csv_writer = csv.writer(file)
        if is_new_file:
            # Brand-new file: emit the column header first.
            csv_writer.writerow(["query", "position", "document_id", "answer", "score"])
        # One row per result, ranked from position 1, with normalized links.
        for rank, result_link in enumerate(doc_results, start=1):
            csv_writer.writerow([query, rank, _clean_doc_id_link(result_link), "", ""])
    logger.debug("Appended search results to csv file")
def _append_answer_to_csv(
    query: str,
    answer: str,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the answer for a query to the CSV file.

    Args:
        query: the question the answer belongs to
        answer: the LLM answer text to record
        csv_path: Path to the CSV file to append to
    """
    file_exists = os.path.isfile(csv_path)
    # Create directory if it doesn't exist
    csv_dir = os.path.dirname(csv_path)
    if csv_dir and not os.path.exists(csv_dir):
        Path(csv_dir).mkdir(parents=True, exist_ok=True)
    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])
        # Answer rows fill only the query and answer columns
        writer.writerow([query, "", "", answer, ""])
    logger.debug("Appended answer to csv file")
class Answer:
def __init__(
self,
@@ -134,6 +329,8 @@ class Answer:
@property
def processed_streamed_output(self) -> AnswerStream:
_HACKATHON_TEST_EXECUTION = False
if self._processed_stream is not None:
yield from self._processed_stream
return
@@ -154,22 +351,120 @@ class Answer:
)
):
run_langgraph = run_dc_graph
elif (
self.graph_config.inputs.persona
and self.graph_config.inputs.persona.description.startswith(
"Hackathon Test"
)
):
_HACKATHON_TEST_EXECUTION = True
run_langgraph = run_hackathon_graph
else:
run_langgraph = run_basic_graph
stream = run_langgraph(
self.graph_config,
)
if _HACKATHON_TEST_EXECUTION:
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
input_data = str(self.graph_config.inputs.prompt_builder.raw_user_query)
if input_data.startswith("["):
input_type = "json"
input_list = json.loads(input_data)
else:
input_type = "list"
input_list = input_data.split(";")
num_examples_with_ground_truth = 0
total_score = 0.0
question = None
for question_num, question_data in enumerate(input_list):
ground_truth_docs = None
if input_type == "json":
question = question_data["question"]
ground_truth = question_data.get("ground_truth")
if ground_truth:
ground_truth_docs = [x.get("doc_link") for x in ground_truth]
logger.info(f"Question {question_num}: {question}")
_append_ground_truth_to_csv(question, ground_truth_docs)
else:
continue
else:
question = question_data
self.graph_config.inputs.prompt_builder.raw_user_query = question
self.graph_config.inputs.prompt_builder.user_message_and_token_cnt = (
HumanMessage(
content=question, additional_kwargs={}, response_metadata={}
),
2,
)
self.graph_config.tooling.force_use_tool.force_use = True
stream = run_langgraph(
self.graph_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
llm_answer_segments: list[str] = []
doc_results: list[str] | None = None
for answer_piece in processed_stream:
if isinstance(answer_piece, OnyxAnswerPiece):
llm_answer_segments.append(answer_piece.answer_piece or "")
elif isinstance(answer_piece, ToolCallFinalResult):
doc_results = [x.get("link") for x in answer_piece.tool_result]
if doc_results:
_append_search_results_to_csv(question, doc_results)
_append_answer_to_csv(question, "".join(llm_answer_segments))
if ground_truth_docs and doc_results:
num_examples_with_ground_truth += 1
doc_score = 0.0
for doc_id in ground_truth_docs:
doc_score += _get_doc_score(doc_id, doc_results)
_append_score_to_csv(question, doc_score)
total_score += doc_score
self._processed_stream = processed_stream
if num_examples_with_ground_truth > 0:
comprehensive_score = total_score / num_examples_with_ground_truth
else:
comprehensive_score = 0
_append_empty_line()
if not question:
raise RuntimeError(f"No questions in input list; {input_list=}")
_append_score_to_csv(question, comprehensive_score)
else:
stream = run_langgraph(
self.graph_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
break
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
self._processed_stream = processed_stream
@property
def llm_answer(self) -> str:

View File

@@ -1012,6 +1012,7 @@ def stream_chat_message_objects(
tools=tools,
db_session=db_session,
use_agentic_search=new_msg_req.use_agentic_search,
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
)
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(

View File

@@ -187,8 +187,12 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
if self.new_messages_and_token_cnts:
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
if (
self.new_messages_and_token_cnts
and isinstance(self.user_message_and_token_cnt[0].content, str)
and self.user_message_and_token_cnt[0].content.startswith("Refer")
):
final_messages_with_tokens.extend(self.new_messages_and_token_cnts[-2:])
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -787,3 +787,7 @@ S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
# Forcing Vespa Language
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")
HACKATHON_OUTPUT_CSV_PATH = os.environ.get(
"HACKATHON_OUTPUT_CSV_PATH", "/tmp/hackathon_output.csv"
)

View File

@@ -483,3 +483,24 @@ def section_relevance_list_impl(
items=final_context_sections,
)
return [ind in llm_indices for ind in range(len(final_context_sections))]
if __name__ == "__main__":
from onyx.db.engine.sql_engine import SqlEngine, get_session_with_current_tenant
from onyx.llm.factory import get_default_llms
SqlEngine.init_engine(pool_size=10, max_overflow=10)
llm, fast_llm = get_default_llms()
with get_session_with_current_tenant() as db_session:
sp = SearchPipeline(
search_request=SearchRequest(
query="What is Onyx?",
),
user=None,
llm=llm,
fast_llm=fast_llm,
skip_query_analysis=True,
db_session=db_session,
)

View File

@@ -1,5 +1,6 @@
import string
from collections.abc import Callable
from enum import Enum
import nltk # type:ignore
from sqlalchemy.orm import Session
@@ -20,6 +21,7 @@ from onyx.context.search.preprocessing.preprocessing import HYBRID_ALPHA_KEYWORD
from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_multilingual_expansion
from onyx.db.search_settings import get_simple_expansion_settings
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa.shared_utils.utils import (
@@ -27,6 +29,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.query_expansion import multilingual_query_expansion
from onyx.secondary_llm_flows.query_expansion import simple_unilingual_query_expansion
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import run_in_background
@@ -41,6 +44,33 @@ from shared_configs.model_server_models import Embedding
logger = setup_logger()
class ExpansionType(Enum):
    """Which query-expansion strategy to apply before retrieval."""

    # NOTE(review): member names are not UPPER_CASE per convention, but they
    # are referenced elsewhere, so renaming would break callers.
    Off = "off"  # no expansion; search the raw query only
    Simple = "simple"  # same-language LLM rephrasings
    Multilingual = "multilingual"  # rephrasings into the configured languages
def _get_expansion_type(
    db_session: Session,
    query: str,
) -> tuple[ExpansionType, list[str], int]:
    """Decide which expansion strategy applies to the given query.

    Returns the strategy plus its settings: the multilingual expansion list
    (empty unless multilingual applies) and the simple-expansion count
    (0 unless simple expansion is enabled).
    """
    # Multi-line queries are too complex for rephrasing to help.
    has_line_breaks = "\n" in query or "\r" in query
    if has_line_breaks:
        return ExpansionType.Off, [], 0

    languages = get_multilingual_expansion(db_session)
    if languages:
        return ExpansionType.Multilingual, languages, 0

    simple_enabled, simple_count = get_simple_expansion_settings(db_session)
    if simple_enabled:
        return ExpansionType.Simple, [], simple_count

    return ExpansionType.Off, [], 0
def _dedupe_chunks(
chunks: list[InferenceChunkUncleaned],
) -> list[InferenceChunkUncleaned]:
@@ -348,6 +378,92 @@ def _simplify_text(text: str) -> str:
).lower()
def _retrieve_search_with_simple_expansion(
    db_session: Session,
    document_index: DocumentIndex,
    query: SearchQuery,
    simple_expansion_count: int,
) -> list[InferenceChunk]:
    """Run retrieval once per unique same-language rephrase and merge results."""
    rephrases = simple_unilingual_query_expansion(
        query=query.query,
        num_expansions=simple_expansion_count,
    )
    # Just to be extra sure, add the original query.
    rephrases.append(query.query)

    seen_simplified: set[str] = set()
    search_tasks: list[tuple[Callable, tuple]] = []
    for candidate in set(rephrases):
        # Sometimes the model rephrases the query in the same language with
        # minor changes; skip near-duplicates so they don't bias the results.
        simplified = _simplify_text(candidate)
        if simplified in seen_simplified:
            continue
        seen_simplified.add(simplified)
        # SearchQuery is a frozen model, so build a copy with the rephrase
        # swapped in and the embedding cleared for recomputation.
        rephrased_query = query.model_copy(
            update={
                "query": candidate,
                "precomputed_query_embedding": None,
            },
            deep=True,
        )
        search_tasks.append(
            (
                doc_index_retrieval,
                (rephrased_query, document_index, db_session),
            )
        )

    parallel_results = run_functions_tuples_in_parallel(search_tasks)
    return combine_retrieval_results(parallel_results)
def _retrieve_search_with_multilingual_expansion(
    db_session: Session,
    document_index: DocumentIndex,
    query: SearchQuery,
    multilingual_expansion: list[str],
) -> list[InferenceChunk]:
    """Run retrieval once per unique multilingual rephrase and merge results."""
    rephrases = multilingual_query_expansion(query.query, multilingual_expansion)
    # Just to be extra sure, add the original query.
    rephrases.append(query.query)

    seen_simplified: set[str] = set()
    search_tasks: list[tuple[Callable, tuple]] = []
    for candidate in set(rephrases):
        # Sometimes the model rephrases the query in the same language with
        # minor changes; skip near-duplicates so they don't bias the results.
        simplified = _simplify_text(candidate)
        if simplified in seen_simplified:
            continue
        seen_simplified.add(simplified)
        # SearchQuery is a frozen model, so build a copy with the rephrase
        # swapped in and the embedding cleared for recomputation.
        rephrased_query = query.model_copy(
            update={
                "query": candidate,
                "precomputed_query_embedding": None,
            },
            deep=True,
        )
        search_tasks.append(
            (
                doc_index_retrieval,
                (rephrased_query, document_index, db_session),
            )
        )

    parallel_results = run_functions_tuples_in_parallel(search_tasks)
    return combine_retrieval_results(parallel_results)
def retrieve_chunks(
query: SearchQuery,
document_index: DocumentIndex,
@@ -358,48 +474,33 @@ def retrieve_chunks(
) -> list[InferenceChunk]:
"""Returns a list of the best chunks from an initial keyword/semantic/ hybrid search."""
multilingual_expansion = get_multilingual_expansion(db_session)
# Don't do query expansion on complex queries, rephrasings likely would not work well
if not multilingual_expansion or "\n" in query.query or "\r" in query.query:
expansion_type, multilingual_expansion, simple_expansion_count = (
_get_expansion_type(
db_session=db_session,
query=query.query,
)
)
if expansion_type == ExpansionType.Off:
top_chunks = doc_index_retrieval(
query=query, document_index=document_index, db_session=db_session
)
else:
simplified_queries = set()
run_queries: list[tuple[Callable, tuple]] = []
# Currently only uses query expansion on multilingual use cases
query_rephrases = multilingual_query_expansion(
query.query, multilingual_expansion
elif expansion_type == ExpansionType.Simple:
top_chunks = _retrieve_search_with_simple_expansion(
db_session=db_session,
document_index=document_index,
query=query,
simple_expansion_count=simple_expansion_count,
)
# Just to be extra sure, add the original query.
query_rephrases.append(query.query)
for rephrase in set(query_rephrases):
# Sometimes the model rephrases the query in the same language with minor changes
# Avoid doing an extra search with the minor changes as this biases the results
simplified_rephrase = _simplify_text(rephrase)
if simplified_rephrase in simplified_queries:
continue
simplified_queries.add(simplified_rephrase)
q_copy = query.model_copy(
update={
"query": rephrase,
# need to recompute for each rephrase
# note that `SearchQuery` is a frozen model, so we can't update
# it below
"precomputed_query_embedding": None,
},
deep=True,
)
run_queries.append(
(
doc_index_retrieval,
(q_copy, document_index, db_session),
)
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
top_chunks = combine_retrieval_results(parallel_search_results)
elif expansion_type == ExpansionType.Multilingual:
top_chunks = _retrieve_search_with_multilingual_expansion(
db_session=db_session,
document_index=document_index,
query=query,
multilingual_expansion=multilingual_expansion,
)
else:
raise RuntimeError
if not top_chunks:
logger.warning(

View File

@@ -56,6 +56,8 @@ def create_search_settings(
contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
multilingual_expansion=search_settings.multilingual_expansion,
enable_simple_query_expansion=search_settings.enable_simple_query_expansion,
simple_expansion_count=search_settings.simple_expansion_count,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
@@ -195,6 +197,26 @@ def get_multilingual_expansion(db_session: Session | None = None) -> list[str]:
return search_settings.multilingual_expansion
def get_simple_expansion_settings(
    db_session: Session | None = None,
) -> tuple[bool, int]:
    """Get simple query expansion settings from the database.

    Returns:
        tuple: (enabled, expansion_count)
    """
    # HACK(review): the DB-backed lookup below is commented out and the
    # settings are hardcoded to (enabled=True, 3 expansions); db_session is
    # currently ignored. TODO: restore the database path before merging.
    # if db_session is None:
    # with get_session_with_current_tenant() as db_session:
    # search_settings = get_current_search_settings(db_session)
    # else:
    # search_settings = get_current_search_settings(db_session)
    # if not search_settings:
    # return False, 2
    # return search_settings.enable_simple_query_expansion, search_settings.simple_expansion_count
    return True, 3
def update_search_settings(
current_settings: SearchSettings,
updated_settings: SavedSearchSettings,

View File

@@ -9,6 +9,25 @@ Query:
{query}
""".strip()
SIMPLE_QUERY_EXPANSION_PROMPT = """
Generate {num_expansions} alternative phrasings of the following query in the same language.
The goal is to express the same intent in different ways that might retrieve different relevant documents.
Guidelines:
- Keep the same meaning and intent as the original query
- Use different vocabulary and sentence structures
- Make each expansion distinct from the others
- Focus on different aspects or angles of the query
- Use synonyms, different verb forms, or alternative phrasings
Original query:
{query}
Respond with EXACTLY {num_expansions} alternative queries, one per line. Do not include any explanations or numbering.
Alternative queries:
""".strip()
SLACK_LANGUAGE_REPHRASE_PROMPT = """
As an AI assistant employed by an organization, \
your role is to transform user messages into concise \
@@ -27,3 +46,4 @@ Query:
# Use the following for easy viewing of prompts
if __name__ == "__main__":
print(LANGUAGE_REPHRASE_PROMPT)
print(SIMPLE_QUERY_EXPANSION_PROMPT)

View File

@@ -12,6 +12,7 @@ from onyx.llm.utils import dict_based_prompt_to_langchain_prompt
from onyx.llm.utils import message_to_string
from onyx.prompts.chat_prompts import HISTORY_QUERY_REPHRASE
from onyx.prompts.miscellaneous_prompts import LANGUAGE_REPHRASE_PROMPT
from onyx.prompts.miscellaneous_prompts import SIMPLE_QUERY_EXPANSION_PROMPT
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import count_punctuation
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -164,3 +165,47 @@ def thread_based_query_rephrase(
logger.debug(f"Rephrased combined query: {rephrased_query}")
return rephrased_query
def simple_unilingual_query_expansion(
    query: str,
    num_expansions: int = 2,
    use_threads: bool = True,
) -> list[str]:
    """
    Generate alternative phrasings of a query in the same language.

    This is useful for improving search recall by trying different ways to
    express the same intent. The original query is always included in the
    returned list.

    Args:
        query: the query to expand
        num_expansions: maximum number of LLM-generated rephrasings to keep
        use_threads: currently unused; kept for interface compatibility

    Returns:
        Up to num_expansions rephrasings plus the original query; falls back
        to [query] alone when Gen AI is disabled.
    """

    def _get_simple_expansion_messages() -> list[dict[str, str]]:
        # Single-turn prompt asking for N rephrasings, one per line.
        return [
            {
                "role": "user",
                "content": SIMPLE_QUERY_EXPANSION_PROMPT.format(
                    query=query, num_expansions=num_expansions
                ),
            },
        ]

    try:
        _, fast_llm = get_default_llms(timeout=5)
    except GenAIDisabledException:
        logger.warning("Unable to perform simple query expansion, Gen AI disabled")
        return [query]

    messages = _get_simple_expansion_messages()
    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
    model_output = message_to_string(fast_llm.invoke(filled_llm_prompt))
    logger.debug(f"Simple query expansion output: {model_output}")

    # Parse the response - expect newline-separated queries
    expanded_queries = [
        line.strip() for line in model_output.split("\n") if line.strip()
    ]

    # Cap the LLM output BEFORE appending the original query. The previous
    # version truncated after appending, which silently dropped the original
    # whenever the model returned more lines than requested.
    expanded_queries = expanded_queries[:num_expansions]
    if query not in expanded_queries:
        expanded_queries.append(query)
    return expanded_queries

View File

@@ -1,7 +1,10 @@
import copy
import csv
import json
import os
from collections.abc import Callable
from collections.abc import Generator
from pathlib import Path
from typing import Any
from typing import cast
from typing import TypeVar
@@ -19,6 +22,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prune_and_merge import prune_and_merge_sections
from onyx.chat.prune_and_merge import prune_sections
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
@@ -62,6 +66,39 @@ SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"
def _append_ranking_stats_to_csv(
    llm_doc_results: list[tuple[int, str]],
    query: str,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append ranking statistics to a CSV file.

    Args:
        llm_doc_results: list of (position, document_id) tuples for the query
        query: the query the rankings belong to
        csv_path: Path to the CSV file to append to
    """
    file_exists = os.path.isfile(csv_path)
    # Create directory if it doesn't exist
    csv_dir = os.path.dirname(csv_path)
    if csv_dir and not os.path.exists(csv_dir):
        Path(csv_dir).mkdir(parents=True, exist_ok=True)
    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])
        # Write the ranking stats, padded to the 5-column header (the
        # previous version wrote only 4 values, leaving the score column
        # misaligned with the header).
        for pos, doc in llm_doc_results:
            writer.writerow([query, pos, doc, "", ""])
    logger.debug(f"Appended {len(llm_doc_results)} ranking stats to {csv_path}")
class SearchResponseSummary(SearchQueryInfo):
top_sections: list[InferenceSection]
rephrased_query: str | None = None
@@ -502,6 +539,12 @@ def yield_search_responses(
)
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
# Append ranking statistics to a CSV file
llm_doc_results = []
for pos, doc in enumerate(llm_docs):
llm_doc_results.append((pos, doc.document_id))
# _append_ranking_stats_to_csv(llm_doc_results, query)
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

View File

@@ -4,16 +4,12 @@ This Python script evaluates the search results for a list of queries.
This script will likely get refactored in the future as an API endpoint.
In the meanwhile, it is used to evaluate the search quality using locally ingested documents.
The key differentiating factor with `answer_quality` is that it can evaluate results without explicit "ground truth" using the reranker as a reference.
## Usage
1. Ensure you have the required dependencies installed and onyx running.
1. Ensure you have the required dependencies installed and onyx running. Note that auth must be disabled for this script to work (`AUTH_TYPE=disabled`, which is the case by default).
2. Ensure a reranker model is configured in the search settings.
This can be checked/modified by opening the admin panel, going to search settings, and ensuring a reranking model is set.
3. Set up the PYTHONPATH permanently:
2. Set up the PYTHONPATH permanently:
Add the following line to your shell configuration file (e.g., `~/.bashrc`, `~/.zshrc`, or `~/.bash_profile`):
```
export PYTHONPATH=$PYTHONPATH:/path/to/onyx/backend
@@ -21,42 +17,41 @@ This can be checked/modified by opening the admin panel, going to search setting
Replace `/path/to/onyx` with the actual path to your Onyx repository.
After adding this line, restart your terminal or run `source ~/.bashrc` (or the appropriate config file) to apply the changes.
4. Navigate to Onyx repo, search_quality folder:
3. Navigate to Onyx repo, **search_quality** folder:
```
cd path/to/onyx/backend/tests/regression/search_quality
```
5. Copy `test_queries.json.template` to `test_queries.json` and add/remove test queries in it. The possible fields are:
4. Copy `test_queries.json.template` to `test_queries.json` and add/remove test queries in it. The possible fields are:
- `question: str` the query
- `question_search: Optional[str]` modified query specifically for the search step
- `ground_truth: Optional[list[GroundTruth]]` a ranked list of expected search results with fields:
- `doc_source: str` document source (e.g., Web, Drive, Linear), currently unused
- `doc_source: str` document source (e.g., web, google_drive, linear), used to normalize links in some cases
- `doc_link: str` link associated with document, used to find corresponding document in local index
- `categories: Optional[list[str]]` list of categories, used to aggregate evaluation results
6. Copy `search_eval_config.yaml.template` to `search_eval_config.yaml` and specify the search and eval parameters
7. Run `run_search_eval.py` to run the search and evaluate the search results
5. Run `run_search_eval.py` to evaluate the queries. All parameters are optional and have sensible defaults:
```
python run_search_eval.py
-d --dataset # Path to the test-set JSON file (default: ./test_queries.json)
-n --num_search # Maximum number of search results to check per query (default: 50)
-a --num_answer # Maximum number of search results to use for answer evaluation (default: 25)
-w --workers # Number of parallel search requests (default: 10)
-q --timeout # Request timeout in seconds (default: 120)
-e --api_endpoint # Base URL of the Onyx API server (default: http://127.0.0.1:8080)
-s --search_only # Only perform search and not answer evaluation (default: false)
-r --rerank_all # Always rerank all search results (default: false)
-t --tenant_id # Tenant ID to use for the evaluation (default: None)
```
8. Optionally, save the generated `test_queries.json` in the export folder to reuse the generated `question_search`, and rerun the search evaluation with alternative search parameters.
6. After the run, an `eval-YYYY-MM-DD-HH-MM-SS` folder is created containing:
## Metrics
There are two main metrics currently implemented:
- ratio_topk: the ratio of documents in the comparison set that are in the topk search results (higher is better, 0-1)
- avg_rank_delta: the average rank difference between the comparison set and search results (lower is better, 0-inf)
* `test_queries.json` the dataset used with the indexed ground truth documents.
* `search_results.json` per-query details.
* `results_by_category.csv` aggregated metrics per category and for "all".
* `search_position_chart.png` bar-chart of ground-truth ranks.
Ratio topk gives a general idea on whether the most relevant documents are appearing first in the search results. Decreasing `eval_topk` will make this metric stricter, requiring relevant documents to appear in a narrow window.
Avg rank delta is another metric which can give insight into the performance of documents not in the topk search results. If none of the comparison documents are in the topk, `ratio_topk` will only show a 0, whereas `avg_rank_delta` will show a higher value the worse the search results get.
Furthermore, there are two versions of the metrics: ground truth, and soft truth.
The ground truth includes documents explicitly listed as relevant in the test dataset. The ground truth metrics will only be computed if a ground truth set is provided for the question and exists in the index.
The soft truth is built on top of the ground truth (if provided), filling the remaining entries with results from the reranker. The soft truth metrics will only be computed if `skip_rerank` is false. Computing the soft truth metric can be extremely slow, especially for large `num_returned_hits`. However, it can provide a good basis when there are many relevant documents in no particular order, or for running quick tests without explicitly having to mention which documents are relevant.
You can copy the generated `test_queries.json` back to the root folder to skip the ground truth documents search step.

View File

@@ -0,0 +1,75 @@
from pydantic import BaseModel
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import SavedSearchDoc
class GroundTruth(BaseModel):
    """A single known-relevant document for a test query."""

    # Source system the document was ingested from (e.g. "web").
    doc_source: DocumentSource
    # Link used to locate the document in the index.
    doc_link: str
class TestQuery(BaseModel):
    """One evaluation case: a question plus its expected relevant documents."""

    question: str
    # Documents expected to appear in the search results (may be empty).
    ground_truth: list[GroundTruth] = []
    # Free-form labels used to bucket metrics (e.g. "keyword", "broad").
    categories: list[str] = []

    # autogenerated:
    # resolved document IDs for the ground_truth entries found in the index;
    # populated during dataset loading, or pre-filled when reusing an export.
    ground_truth_docids: list[str] = []
class EvalConfig(BaseModel):
    """Runtime settings for a search-quality evaluation run."""

    # Maximum number of search results requested per query.
    max_search_results: int
    # Maximum number of retrieved documents fed to answer evaluation.
    max_answer_context: int
    # Number of parallel worker threads issuing QA requests.
    num_workers: int
    # Per-HTTP-request timeout in seconds.
    request_timeout: int
    # Base URL of the Onyx API server.
    api_url: str
    # When True, skip answer generation/evaluation and only measure search.
    search_only: bool
    # When True, rerank all returned hits instead of the configured top-N.
    rerank_all: bool
class OneshotQAResult(BaseModel):
    """Timed result of a single OneShot QA API call."""

    # Wall-clock seconds for the request round-trip.
    time_taken: float
    # Retrieved documents, in rank order.
    top_documents: list[SavedSearchDoc]
    # Generated answer, or None when answer generation was skipped.
    answer: str | None
class RetrievedDocument(BaseModel):
    """Minimal view of a retrieved document used as answer-evaluation context."""

    document_id: str
    content: str
class AnalysisSummary(BaseModel):
    """Per-query outcome of one search plus (optional) answer evaluation."""

    question: str
    categories: list[str]
    # Whether any ground-truth document appeared in the search results.
    found: bool
    # 1-based rank of the first ground-truth hit; may hold the retrieval cap
    # when nothing was found (see the analyzer) — None when unknown.
    rank: int | None
    total_results: int
    ground_truth_count: int
    # Generated answer; None in search-only mode or when generation failed.
    answer: str | None = None
    # RAGAS answer-evaluation scores; None when evaluation was not run.
    response_relevancy: float | None = None
    response_groundedness: float | None = None
    faithfulness: float | None = None
    # Retrieved document contents used for answer evaluation.
    retrieved: list[RetrievedDocument] = []
    # Request latency in seconds; None when unknown.
    time_taken: float | None = None
class SearchMetrics(BaseModel):
    """Aggregate search-quality metrics over a set of queries."""

    total_queries: int
    # Number of queries whose ground truth appeared in the results.
    found_count: int

    # for found results
    best_rank: int
    worst_rank: int
    average_rank: float
    # Percentage of queries whose ground truth landed within the top k,
    # keyed by k.
    top_k_accuracy: dict[int, float]
class AnswerMetrics(BaseModel):
    """Aggregate RAGAS answer-quality metrics (averaged over scored queries)."""

    average_response_relevancy: float
    average_response_groundedness: float
    average_faithfulness: float
class CombinedMetrics(SearchMetrics, AnswerMetrics):
    """Search and answer metrics combined, plus average request latency."""

    # Mean request latency in seconds across queries.
    average_time_taken: float

View File

@@ -1,151 +1,692 @@
import csv
import json
import os
import time
from collections import defaultdict
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from pathlib import Path
import matplotlib.pyplot as plt
import requests
from pydantic import ValidationError
from requests.exceptions import RequestException
from retry import retry
from ee.onyx.server.query_and_chat.models import OneShotQARequest
from ee.onyx.server.query_and_chat.models import OneShotQAResponse
from onyx.chat.models import ThreadMessage
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.constants import MessageType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import RerankingDetails
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.context.search.models import RetrievalDetails
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_multilingual_expansion
from onyx.document_index.factory import get_default_document_index
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from tests.regression.search_quality.util_config import load_config
from tests.regression.search_quality.util_data import export_test_queries
from tests.regression.search_quality.util_data import load_test_queries
from tests.regression.search_quality.util_eval import evaluate_one_query
from tests.regression.search_quality.util_eval import get_corresponding_document
from tests.regression.search_quality.util_eval import metric_names
from tests.regression.search_quality.util_retrieve import rerank_one_query
from tests.regression.search_quality.util_retrieve import search_one_query
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from tests.regression.search_quality.models import AnalysisSummary
from tests.regression.search_quality.models import CombinedMetrics
from tests.regression.search_quality.models import EvalConfig
from tests.regression.search_quality.models import OneshotQAResult
from tests.regression.search_quality.models import TestQuery
from tests.regression.search_quality.utils import find_document
from tests.regression.search_quality.utils import ragas_evaluate
from tests.regression.search_quality.utils import search_docs_to_doc_contexts
logger = setup_logger(__name__)
GENERAL_HEADERS = {"Content-Type": "application/json"}
TOP_K_LIST = [1, 3, 5, 10]
def run_search_eval() -> None:
config = load_config()
test_queries = load_test_queries()
# export related
export_path = Path(config.export_folder)
export_test_queries(test_queries, export_path / "test_queries.json")
search_result_path = export_path / "search_results.csv"
eval_path = export_path / "eval_results.csv"
aggregate_eval_path = export_path / "aggregate_eval.csv"
aggregate_results: dict[str, list[list[float]]] = defaultdict(
lambda: [[] for _ in metric_names]
)
class SearchAnswerAnalyzer:
def __init__(
self,
config: EvalConfig,
tenant_id: str | None = None,
):
if not MULTI_TENANT:
logger.info("Running in single-tenant mode")
tenant_id = POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
elif tenant_id is None:
raise ValueError("Tenant ID is required for multi-tenant")
with get_session_with_current_tenant() as db_session:
multilingual_expansion = get_multilingual_expansion(db_session)
search_settings = get_current_search_settings(db_session)
document_index = get_default_document_index(search_settings, None)
rerank_settings = RerankingDetails.from_db_model(search_settings)
# rerank_settings = RerankingDetails.from_db_model(search_settings)
self.config = config
self.tenant_id = tenant_id
self.results: list[AnalysisSummary] = []
self.stats: dict[str, list[AnalysisSummary]] = defaultdict(list)
self.metrics: dict[str, CombinedMetrics] = {}
if config.skip_rerank:
logger.warning("Reranking is disabled, evaluation will not run")
elif rerank_settings.rerank_model_name is None:
raise ValueError(
"Reranking is enabled but no reranker is configured. "
"Please set the reranker in the admin panel search settings."
)
# get search related settings
self._rerank_settings = self._get_rerank_settings()
# run search and evaluate
logger.info(
"Running search and evaluation... "
f"Individual search and evaluation results will be saved to {search_result_path} and {eval_path}"
def run_analysis(self, dataset_path: Path, export_path: Path) -> None:
    """Run every test query in parallel and collect per-query summaries.

    Loads and validates the dataset, exports the processed copy, fans the
    queries out over a thread pool, and finally builds aggregate metrics.
    Populates self.results (in dataset order), self.stats, and self.metrics.
    """
    # load and save the dataset
    dataset = self._load_dataset(dataset_path)
    dataset_size = len(dataset)

    # export the processed dataset (with resolved ground-truth doc IDs) so a
    # future run can skip the lookup step
    dataset_export_path = export_path / "test_queries.json"
    with dataset_export_path.open("w") as f:
        dataset_serializable = [q.model_dump(mode="json") for q in dataset]
        json.dump(dataset_serializable, f, indent=4)

    # run the analysis
    logger.info("Starting analysis of %d queries...", dataset_size)
    logger.info("Using %d parallel workers", self.config.num_workers)

    # keep the original index so results can be re-sorted into dataset order
    indexed_test_cases = [(i, test_case) for i, test_case in enumerate(dataset)]
    indexed_results: dict[int, AnalysisSummary] = {}

    with ThreadPoolExecutor(max_workers=self.config.num_workers) as executor:
        future_to_index = {
            executor.submit(
                self._run_and_analyze_one_wrapper, test_case_with_index
            ): test_case_with_index[0]
            for test_case_with_index in indexed_test_cases
        }

        # process completed tasks as they finish
        for completed_count, future in enumerate(as_completed(future_to_index), 1):
            try:
                index, result = future.result()
                indexed_results[index] = result
                # update category stats on the fly
                self.stats["all"].append(result)
                for cat in result.categories or ["uncategorized"]:
                    self.stats[cat].append(result)
            except Exception as e:
                # a failed query is reported but does not abort the run
                print(f"[{completed_count}/{dataset_size}] ✗ Error: {e}")
                continue

            # print progress with query info (question truncated for display)
            question = (
                result.question[:50] + "..."
                if len(result.question) > 50
                else result.question
            )
            status = "✓ Found" if result.found else "✗ Not found"
            rank_info = f" (rank {result.rank})" if result.found else ""
            print(
                f"[{completed_count}/{dataset_size}] {status}{rank_info}: {question}"
            )

    # sort results by original order and build the metrics
    self.results = [indexed_results[i] for i in sorted(indexed_results.keys())]
    self._build_metrics()
def generate_summary(self) -> None:
logger.info("Generating summary...")
metrics_all = self.metrics.get("all")
if metrics_all is None:
logger.warning("Nothing to summarize")
return
total_queries = metrics_all.total_queries
found_count = metrics_all.found_count
print(
f"Total test queries: {total_queries}\n"
f"Ground truth found: {found_count} "
f"({found_count / total_queries * 100:.1f}%)\n"
f"Ground truth not found: {total_queries - found_count} "
f"({(total_queries - found_count) / total_queries * 100:.1f}%)"
)
with (
search_result_path.open("w") as search_file,
eval_path.open("w") as eval_file,
):
search_csv_writer = csv.writer(search_file)
eval_csv_writer = csv.writer(eval_file)
search_csv_writer.writerow(
["source", "query", "rank", "score", "doc_id", "chunk_id"]
if metrics_all.found_count > 0:
print(
"\nRank statistics (for found results):\n"
f" Average rank: {metrics_all.average_rank:.2f}\n"
f" Best rank: {metrics_all.best_rank}\n"
f" Worst rank: {metrics_all.worst_rank}\n"
)
eval_csv_writer.writerow(["query", *metric_names])
for k, acc in metrics_all.top_k_accuracy.items():
print(f" Top-{k} accuracy: {acc:.1f}%")
print(f"Average time taken: {metrics_all.average_time_taken:.2f}s")
for query in test_queries:
# search and write results
assert query.question_search is not None
search_chunks = search_one_query(
query.question_search,
multilingual_expansion,
document_index,
db_session,
config,
)
for rank, result in enumerate(search_chunks):
search_csv_writer.writerow(
if not self.config.search_only:
print(
"\nAnswer evaluation metrics:\n"
f" Average response relevancy: {metrics_all.average_response_relevancy:.2f}\n"
f" Average response groundedness: {metrics_all.average_response_groundedness:.2f}\n"
f" Average faithfulness: {metrics_all.average_faithfulness:.2f}"
)
def generate_detailed_report(self, export_path: Path) -> None:
logger.info("Generating detailed report...")
# save results for future inspection
results_json_path = export_path / "search_results.json"
with results_json_path.open("w") as f:
json.dump([r.model_dump(mode="json") for r in self.results], f, indent=4)
logger.info("Saved search results to %s", results_json_path)
# save results by category
csv_path = export_path / "results_by_category.csv"
with csv_path.open("w", newline="") as csv_file:
csv_writer = csv.writer(csv_file)
csv_writer.writerow(
[
"category",
"total_queries",
"found",
"percent_found",
"avg_rank_found",
*(
[
"search",
query.question_search,
rank,
result.score,
result.document_id,
result.chunk_id,
"avg_response_relevancy",
"avg_response_groundedness",
"avg_faithfulness",
]
)
rerank_chunks = []
if not config.skip_rerank:
# rerank and write results
rerank_chunks = rerank_one_query(
query.question, search_chunks, rerank_settings
)
for rank, result in enumerate(rerank_chunks):
search_csv_writer.writerow(
[
"rerank",
query.question,
rank,
result.score,
result.document_id,
result.chunk_id,
]
)
# evaluate and write results
truth_documents = [
doc
for truth in query.ground_truth
if (doc := get_corresponding_document(truth.doc_link, db_session))
if not self.config.search_only
else []
),
"avg_time_taken_sec",
]
metrics = evaluate_one_query(
search_chunks, rerank_chunks, truth_documents, config.eval_topk
)
for category, results in sorted(self.stats.items()):
if not results:
continue
metrics = self.metrics[category]
found_count = metrics.found_count
total_count = metrics.total_queries
accuracy = found_count / total_count * 100 if total_count > 0 else 0
print(
f"\n{category.upper()}:"
f" total queries: {total_count}\n"
f" found: {found_count} ({accuracy:.1f}%)"
)
metric_vals = [getattr(metrics, field) for field in metric_names]
eval_csv_writer.writerow([query.question, *metric_vals])
avg_rank = metrics.average_rank if metrics.found_count > 0 else None
if avg_rank is not None:
print(f" average rank (when found): {avg_rank:.2f}")
if not self.config.search_only:
print(
f" average response relevancy: {metrics.average_response_relevancy:.2f}\n"
f" average response groundedness: {metrics.average_response_groundedness:.2f}\n"
f" average faithfulness: {metrics.average_faithfulness:.2f}"
)
print(f" average time taken: {metrics.average_time_taken:.2f}s")
# add to aggregation
for category in ["all"] + query.categories:
for i, val in enumerate(metric_vals):
if val is not None:
aggregate_results[category][i].append(val)
# aggregate and write results
with aggregate_eval_path.open("w") as file:
aggregate_csv_writer = csv.writer(file)
aggregate_csv_writer.writerow(["category", *metric_names])
for category, agg_metrics in aggregate_results.items():
aggregate_csv_writer.writerow(
csv_writer.writerow(
[
category,
total_count,
found_count,
f"{accuracy:.1f}",
f"{avg_rank:.2f}" if avg_rank is not None else "",
*(
sum(metric) / len(metric) if metric else None
for metric in agg_metrics
[
f"{metrics.average_response_relevancy:.2f}",
f"{metrics.average_response_groundedness:.2f}",
f"{metrics.average_faithfulness:.2f}",
]
if not self.config.search_only
else []
),
f"{metrics.average_time_taken:.2f}",
]
)
logger.info("Saved category breakdown csv to %s", csv_path)
def generate_chart(self, export_path: Path) -> None:
    """Plot a bar chart of ground-truth ranks and save it as a PNG.

    One bar per rank position, plus a red "not found" bar on the far
    right when some queries had no ground-truth hit. Writes
    `search_position_chart.png` into *export_path* and shows the figure.
    """
    logger.info("Generating search position chart...")
    found_results = [r for r in self.results if r.found]
    not_found_count = len([r for r in self.results if not r.found])
    if not found_results and not_found_count == 0:
        logger.warning("No results to chart")
        return

    # count occurrences at each rank position
    rank_counts: dict[int, int] = defaultdict(int)
    for result in found_results:
        if result.rank is not None:
            rank_counts[result.rank] += 1

    # create the data for plotting (contiguous positions 1..max_rank,
    # zero-filled where no query landed at that rank)
    if found_results:
        max_rank = max(rank_counts.keys())
        positions = list(range(1, max_rank + 1))
        counts = [rank_counts.get(pos, 0) for pos in positions]
    else:
        positions = []
        counts = []

    # add the "not found" bar on the far right
    if not_found_count > 0:
        # add some spacing between found positions and "not found"
        not_found_position = (max(positions) + 2) if positions else 1
        positions.append(not_found_position)
        counts.append(not_found_count)
        # create labels for x-axis
        x_labels = [str(pos) for pos in positions[:-1]] + [
            f"not found\n(>{self.config.max_search_results})"
        ]
    else:
        x_labels = [str(pos) for pos in positions]

    # create the figure and bar chart
    plt.figure(figsize=(14, 6))
    # use different colors for found (blue) vs not found (red)
    colors = (
        ["#3498db"] * (len(positions) - 1) + ["#e74c3c"]
        if not_found_count > 0
        else ["#3498db"] * len(positions)
    )
    bars = plt.bar(
        positions, counts, color=colors, alpha=0.7, edgecolor="black", linewidth=0.5
    )

    # customize the chart
    plt.xlabel("Position in Search Results", fontsize=12)
    plt.ylabel("Number of Ground Truth Documents", fontsize=12)
    plt.title(
        "Ground Truth Document Positions in Search Results",
        fontsize=14,
        fontweight="bold",
    )
    plt.grid(axis="y", alpha=0.3)

    # add value labels on top of each bar
    for bar, count in zip(bars, counts):
        if count > 0:
            plt.text(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + 0.1,
                str(count),
                ha="center",
                va="bottom",
                fontweight="bold",
            )

    # set x-axis labels (rotated when the long "not found" label is present)
    plt.xticks(positions, x_labels, rotation=45 if not_found_count > 0 else 0)

    # add legend if we have both found and not found
    if not_found_count > 0 and found_results:
        from matplotlib.patches import Patch

        legend_elements = [
            Patch(facecolor="#3498db", alpha=0.7, label="Found in Results"),
            Patch(facecolor="#e74c3c", alpha=0.7, label="Not Found"),
        ]
        plt.legend(handles=legend_elements, loc="upper right")

    # make layout tight and save
    plt.tight_layout()
    chart_file = export_path / "search_position_chart.png"
    plt.savefig(chart_file, dpi=300, bbox_inches="tight")
    logger.info("Search position chart saved to: %s", chart_file)
    plt.show()
def _load_dataset(self, dataset_path: Path) -> list[TestQuery]:
    """Load test queries from a JSON file, resolving ground-truth doc IDs.

    Entries that fail validation, or whose ground-truth documents cannot
    be found in the index, are skipped with a log message. Entries that
    already carry resolved IDs (copied from a previous run's export) are
    accepted as-is.
    """
    with dataset_path.open("r") as src:
        raw_entries: list[dict] = json.load(src)

    validated: list[TestQuery] = []
    for entry in raw_entries:
        try:
            case = TestQuery(**entry)
        except ValidationError as err:
            logger.error("Incorrectly formatted query: %s", err)
            continue

        # in case the dataset was copied from the previous run export
        if case.ground_truth_docids:
            validated.append(case)
            continue

        # resolve each ground-truth link against the document index
        with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
            for truth in case.ground_truth:
                matched = find_document(truth, db_session)
                if matched:
                    case.ground_truth_docids.append(matched.id)

        if not case.ground_truth_docids:
            logger.warning(
                "No ground truth documents found for query: %s, skipping...",
                case.question,
            )
            continue
        validated.append(case)

    return validated
@retry(tries=3, delay=1, backoff=2)
def _perform_oneshot_qa(self, query: str) -> OneshotQAResult:
    """Perform a OneShot QA query against the Onyx API and time it.

    Returns:
        OneshotQAResult holding the request latency, the ranked documents,
        and the generated answer (None when generation was skipped).

    Raises:
        RuntimeError: if the HTTP request fails (after retries) or the
            response contains no documents.
    """
    # create the thread message
    messages = [ThreadMessage(message=query, sender=None, role=MessageType.USER)]

    # create filters (empty to search all sources)
    filters = IndexFilters(
        source_type=None,
        document_set=None,
        time_cutoff=None,
        tags=None,
        access_control_list=None,
    )

    # create the OneShot QA request
    qa_request = OneShotQARequest(
        messages=messages,
        prompt_id=0,  # default prompt
        persona_id=0,  # default persona
        retrieval_options=RetrievalDetails(
            run_search=OptionalSearchSetting.ALWAYS,
            real_time=True,
            filters=filters,
            enable_auto_detect_filters=False,
            limit=self.config.max_search_results,
        ),
        rerank_settings=self._rerank_settings,
        return_contexts=True,
        skip_gen_ai_answer_generation=self.config.search_only,
    )

    # send the request
    response = None
    try:
        request_data = qa_request.model_dump()
        start_time = time.monotonic()
        response = requests.post(
            url=f"{self.config.api_url}/query/answer-with-citation",
            json=request_data,
            headers=GENERAL_HEADERS,
            timeout=self.config.request_timeout,
        )
        time_taken = time.monotonic() - start_time
        response.raise_for_status()
        result = OneShotQAResponse.model_validate(response.json())
        # extract documents from the QA response
        if result.docs:
            return OneshotQAResult(
                time_taken=time_taken,
                top_documents=result.docs.top_documents,
                answer=result.answer,
            )
    except RequestException as e:
        # BUGFIX: the original conditional expression bound to the entire
        # concatenated f-string, so a failure with no response raised a
        # bare RuntimeError("") and the error detail was lost.
        detail = ""
        if response is not None:
            try:
                detail = f" Response: {response.json()}"
            except ValueError:
                # body was not JSON; fall back to the raw text
                detail = f" Response: {response.text}"
        raise RuntimeError(
            f"OneShot QA failed for query '{query}': {e}.{detail}"
        ) from e

    # reached when the response validated but carried no documents
    raise RuntimeError(f"OneShot QA returned no documents for query {query}")
def _run_and_analyze_one(self, test_case: TestQuery) -> AnalysisSummary:
    """Run one QA request and score it against the test case's ground truth.

    Computes the 1-based rank of the first ground-truth hit (defaulting to
    the retrieval cap when none is found) and, unless running search-only,
    evaluates the generated answer with RAGAS. Evaluation errors are logged
    and leave the corresponding scores as None rather than failing the query.
    """
    result = self._perform_oneshot_qa(test_case.question)

    # compute rank
    # rank keeps the retrieval cap as a pessimistic default when not found
    rank = self.config.max_search_results
    found = False
    ground_truths = set(test_case.ground_truth_docids)
    for i, doc in enumerate(result.top_documents, 1):
        if doc.document_id in ground_truths:
            rank = i
            found = True
            break

    # get the search contents
    retrieved = search_docs_to_doc_contexts(result.top_documents)

    # do answer evaluation
    response_relevancy: float | None = None
    response_groundedness: float | None = None
    faithfulness: float | None = None
    # only the first max_answer_context documents are used as context
    contexts = [c.content for c in retrieved[: self.config.max_answer_context]]
    if not self.config.search_only:
        if result.answer is None:
            logger.error(
                "No answer found for query: %s, skipping answer evaluation",
                test_case.question,
            )
        else:
            try:
                ragas_result = ragas_evaluate(
                    question=test_case.question,
                    answer=result.answer,
                    contexts=contexts,
                ).scores[0]
                response_relevancy = ragas_result["answer_relevancy"]
                response_groundedness = ragas_result["nv_response_groundedness"]
                faithfulness = ragas_result["faithfulness"]
            except Exception as e:
                # leave scores as None; the aggregate metrics skip them
                logger.error(
                    "Error evaluating answer for query %s: %s",
                    test_case.question,
                    e,
                )

    return AnalysisSummary(
        question=test_case.question,
        categories=test_case.categories,
        found=found,
        rank=rank,
        total_results=len(result.top_documents),
        ground_truth_count=len(test_case.ground_truth_docids),
        answer=result.answer,
        response_relevancy=response_relevancy,
        response_groundedness=response_groundedness,
        faithfulness=faithfulness,
        retrieved=retrieved,
        time_taken=result.time_taken,
    )
def _run_and_analyze_one_wrapper(
    self, test_case_with_index: tuple[int, TestQuery]
) -> tuple[int, AnalysisSummary]:
    """Run one test case, carrying its dataset index through for re-ordering.

    Thread-pool adapter: as_completed() yields results out of order, so the
    original index is threaded through alongside the summary.
    """
    idx = test_case_with_index[0]
    case = test_case_with_index[1]
    summary = self._run_and_analyze_one(case)
    return (idx, summary)
def _compute_combined_metrics(
    self, results: list[AnalysisSummary]
) -> CombinedMetrics:
    """Aggregate analysis summaries into CombinedMetrics.

    Every average falls back to 0.0 when there is no sample to average
    (no found ranks, no scored answers, no timings).
    """

    def _avg(scores: list[float]) -> float:
        # BUGFIX: guard against empty lists (e.g. every answer evaluation
        # failed and all scores are None), which previously raised
        # ZeroDivisionError.
        return sum(scores) / len(scores) if scores else 0.0

    total_queries = len(results)
    found_ranks = [r.rank for r in results if r.found and r.rank is not None]
    found_count = len(found_ranks)

    # rank statistics only make sense over found results
    best_rank = min(found_ranks) if found_ranks else 0
    worst_rank = max(found_ranks) if found_ranks else 0
    average_rank = _avg(found_ranks)

    # answer metrics: average only over queries that actually got scored
    response_relevancy = 0.0
    response_groundedness = 0.0
    faithfulness = 0.0
    if not self.config.search_only:
        response_relevancy = _avg(
            [r.response_relevancy for r in results if r.response_relevancy is not None]
        )
        response_groundedness = _avg(
            [
                r.response_groundedness
                for r in results
                if r.response_groundedness is not None
            ]
        )
        faithfulness = _avg(
            [r.faithfulness for r in results if r.faithfulness is not None]
        )

    # top-k accuracy is a percentage over ALL queries, not just found ones
    top_k_accuracy: dict[int, float] = {}
    for k in TOP_K_LIST:
        hits = sum(1 for rank in found_ranks if rank <= k)
        top_k_accuracy[k] = (hits / total_queries * 100) if total_queries else 0.0

    times = [r.time_taken for r in results if r.time_taken is not None]
    avg_time_taken = _avg(times)

    return CombinedMetrics(
        total_queries=total_queries,
        found_count=found_count,
        best_rank=best_rank,
        worst_rank=worst_rank,
        average_rank=average_rank,
        top_k_accuracy=top_k_accuracy,
        average_response_relevancy=response_relevancy,
        average_response_groundedness=response_groundedness,
        average_faithfulness=faithfulness,
        average_time_taken=avg_time_taken,
    )
def _build_metrics(self) -> None:
    """Compute aggregate metrics for every category bucket in self.stats."""
    computed: dict[str, CombinedMetrics] = {}
    for category, summaries in self.stats.items():
        computed[category] = self._compute_combined_metrics(summaries)
    self.metrics = computed
def _get_rerank_settings(self) -> RerankingDetails | None:
    """Fetch the tenant's reranking settings from the database.

    Returns None when no search settings exist or the lookup fails.
    When `rerank_all` is set, widens `num_rerank` to cover every
    retrieved hit (up to `max_search_results`).
    """
    try:
        with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
            db_settings = get_current_search_settings(db_session)
            if db_settings:
                details = RerankingDetails.from_db_model(db_settings)
                if self.config.rerank_all:
                    # override the num_rerank to the eval limit
                    details = details.model_copy(
                        update={"num_rerank": self.config.max_search_results}
                    )
                return details
    except Exception as e:
        logger.warning("Could not load rerank settings from DB: %s", e)
    return None
def run_search_eval(
    dataset_path: Path,
    config: EvalConfig,
    tenant_id: str | None,
) -> None:
    """Run the full search/answer quality evaluation end to end.

    Verifies prerequisites (OPENAI_API_KEY when answers are evaluated, and
    that the Onyx API is reachable), creates a timestamped export folder
    next to this file, then runs analysis, summary, report, and chart.

    Raises:
        RuntimeError: if prerequisites are unmet or the API is unreachable.
    """
    if not config.search_only and not os.environ.get("OPENAI_API_KEY"):
        raise RuntimeError("OPENAI_API_KEY is required for answer evaluation")

    # check onyx is running
    try:
        response = requests.get(
            f"{config.api_url}/health", timeout=config.request_timeout
        )
        response.raise_for_status()
    except RequestException as e:
        raise RuntimeError(f"Could not connect to Onyx API: {e}") from e
    logger.info("Onyx API is running")

    # create the export folder
    # BUGFIX: derive the export root locally instead of reading the
    # `current_dir` global that is only assigned under `__main__` — the
    # original raised NameError when this function was imported and called
    # programmatically. The resulting path is identical.
    base_dir = Path(__file__).parent
    export_path = base_dir / datetime.now().strftime("eval-%Y-%m-%d-%H-%M-%S")
    export_path.mkdir(parents=True, exist_ok=True)
    logger.info("Created export folder: %s", export_path)

    # run the search eval
    analyzer = SearchAnswerAnalyzer(config=config, tenant_id=tenant_id)
    analyzer.run_analysis(dataset_path, export_path)
    analyzer.generate_summary()
    analyzer.generate_detailed_report(export_path)
    analyzer.generate_chart(export_path)
if __name__ == "__main__":
if MULTI_TENANT:
raise ValueError("Multi-tenant is not supported currently")
import argparse
current_dir = Path(__file__).parent
parser = argparse.ArgumentParser(description="Run search quality evaluation.")
parser.add_argument(
"-d",
"--dataset",
type=Path,
default=current_dir / "test_queries.json",
help="Path to the test-set JSON file (default: %(default)s).",
)
parser.add_argument(
"-n",
"--num_search",
type=int,
default=50,
help="Maximum number of search results to check per query (default: %(default)s).",
)
parser.add_argument(
"-a",
"--num_answer",
type=int,
default=25,
help="Maximum number of search results to use for answer evaluation (default: %(default)s).",
)
parser.add_argument(
"-w",
"--workers",
type=int,
default=10,
help="Number of parallel search requests (default: %(default)s).",
)
parser.add_argument(
"-q",
"--timeout",
type=int,
default=120,
help="Request timeout in seconds (default: %(default)s).",
)
parser.add_argument(
"-e",
"--api_endpoint",
type=str,
default="http://127.0.0.1:8080",
help="Base URL of the Onyx API server (default: %(default)s).",
)
parser.add_argument(
"-s",
"--search_only",
action="store_true",
default=False,
help="Only perform search and not answer evaluation (default: %(default)s).",
)
parser.add_argument(
"-r",
"--rerank_all",
action="store_true",
default=False,
help="Always rerank all search results (default: %(default)s).",
)
parser.add_argument(
"-t",
"--tenant_id",
type=str,
default=None,
help="Tenant ID to use for the evaluation (default: %(default)s).",
)
args = parser.parse_args()
SqlEngine.init_engine(
pool_size=POSTGRES_API_SERVER_POOL_SIZE,
@@ -153,9 +694,21 @@ if __name__ == "__main__":
)
try:
run_search_eval()
run_search_eval(
args.dataset,
EvalConfig(
max_search_results=args.num_search,
max_answer_context=args.num_answer,
num_workers=args.workers,
request_timeout=args.timeout,
api_url=args.api_endpoint,
search_only=args.search_only,
rerank_all=args.rerank_all,
),
args.tenant_id,
)
except Exception as e:
logger.error(f"Error running search evaluation: {e}")
raise e
logger.error("Unexpected error during search evaluation: %s", e)
raise
finally:
SqlEngine.reset_engine()

View File

@@ -1,16 +0,0 @@
# Search Parameters
HYBRID_ALPHA: 0.5
HYBRID_ALPHA_KEYWORD: 0.4
DOC_TIME_DECAY: 0.5
NUM_RETURNED_HITS: 50 # Setting to a higher value will improve evaluation quality but increase reranking time
RANK_PROFILE: 'semantic'
OFFSET: 0
TITLE_CONTENT_RATIO: 0.1
USER_EMAIL: null # User email to use for testing, modifies access control list, null means only public files
# Evaluation parameters
SKIP_RERANK: false # Whether to skip reranking, reranking must be enabled to evaluate the search results
EVAL_TOPK: 5 # Number of top results from the searcher and reranker to evaluate, lower means stricter evaluation
# Export file, will export a csv file with the results and a json file with the parameters
EXPORT_FOLDER: "eval-%Y-%m-%d-%H-%M-%S"

View File

@@ -3,20 +3,18 @@
"question": "What is Onyx?",
"ground_truth": [
{
"doc_source": "Web",
"doc_source": "web",
"doc_link": "https://docs.onyx.app/more/use_cases/overview"
},
{
"doc_source": "Web",
"doc_source": "web",
"doc_link": "https://docs.onyx.app/more/use_cases/ai_platform"
}
],
"categories": [
"keyword",
"broad"
"broad",
"easy"
]
},
{
"question": "What is the meaning of life?"
}
]

View File

@@ -1,75 +0,0 @@
from datetime import datetime
from pathlib import Path
import yaml
from pydantic import BaseModel
from onyx.agents.agent_search.shared_graph_utils.models import QueryExpansionType
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.chat_configs import HYBRID_ALPHA
from onyx.configs.chat_configs import HYBRID_ALPHA_KEYWORD
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.utils.logger import setup_logger
logger = setup_logger(__name__)
class SearchEvalConfig(BaseModel):
    """Search parameters and evaluation settings loaded from the YAML config."""

    # Hybrid search weighting between vector and keyword scores.
    hybrid_alpha: float
    hybrid_alpha_keyword: float
    # Time-decay factor applied to document recency.
    doc_time_decay: float
    # Number of hits to retrieve; higher improves eval quality but slows reranking.
    num_returned_hits: int
    rank_profile: QueryExpansionType
    offset: int
    title_content_ratio: float
    # User email for access control; None means only public documents.
    user_email: str | None
    # Whether to skip reranking (reranking is required for soft-truth eval).
    skip_rerank: bool
    # Number of top results to evaluate; lower means stricter evaluation.
    eval_topk: int
    # Folder (already strftime-expanded) where results are exported.
    export_folder: str
def load_config() -> SearchEvalConfig:
    """Loads the search evaluation configs from the config file.

    Reads `search_eval_config.yaml` next to this file, creates the export
    folder (its name is a strftime pattern), and writes the effective config
    back into the export folder for reproducibility.

    Raises:
        FileNotFoundError: if the config file does not exist.
    """
    # open the config file
    current_dir = Path(__file__).parent
    config_path = current_dir / "search_eval_config.yaml"
    if not config_path.exists():
        raise FileNotFoundError(f"Search eval config file not found at {config_path}")
    with config_path.open("r") as file:
        config_raw = yaml.safe_load(file)

    # create the export folder
    # the folder name is a strftime pattern, e.g. "eval-%Y-%m-%d-%H-%M-%S"
    export_folder = config_raw.get("EXPORT_FOLDER", "eval-%Y-%m-%d-%H-%M-%S")
    export_folder = datetime.now().strftime(export_folder)
    export_path = Path(export_folder)
    export_path.mkdir(parents=True, exist_ok=True)
    logger.info(f"Created export folder: {export_path}")

    # create the config
    # each value falls back to the application-wide default when unset
    config = SearchEvalConfig(
        hybrid_alpha=config_raw.get("HYBRID_ALPHA", HYBRID_ALPHA),
        hybrid_alpha_keyword=config_raw.get(
            "HYBRID_ALPHA_KEYWORD", HYBRID_ALPHA_KEYWORD
        ),
        doc_time_decay=config_raw.get("DOC_TIME_DECAY", DOC_TIME_DECAY),
        num_returned_hits=config_raw.get("NUM_RETURNED_HITS", NUM_RETURNED_HITS),
        rank_profile=config_raw.get("RANK_PROFILE", QueryExpansionType.SEMANTIC),
        offset=config_raw.get("OFFSET", 0),
        title_content_ratio=config_raw.get("TITLE_CONTENT_RATIO", TITLE_CONTENT_RATIO),
        user_email=config_raw.get("USER_EMAIL"),
        skip_rerank=config_raw.get("SKIP_RERANK", False),
        eval_topk=config_raw.get("EVAL_TOPK", 5),
        export_folder=export_folder,
    )
    logger.info(f"Using search parameters: {config}")

    # export the config
    config_file = export_path / "search_eval_config.yaml"
    with config_file.open("w") as file:
        config_dict = config.model_dump(mode="python")
        # the enum is not directly YAML-serializable; store its string value
        config_dict["rank_profile"] = config.rank_profile.value
        yaml.dump(config_dict, file, sort_keys=False)
    logger.info(f"Exported config to {config_file}")

    return config

View File

@@ -1,166 +0,0 @@
import json
from pathlib import Path
from typing import cast
from typing import Optional
from langgraph.types import StreamWriter
from pydantic import BaseModel
from pydantic import ValidationError
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.chat.models import PromptConfig
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.persona import get_persona_by_id
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.interfaces import LLM
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
logger = setup_logger()
class GroundTruth(BaseModel):
    """A known-relevant document reference for a test query."""

    # Source system name (e.g. "web").
    doc_source: str
    doc_link: str
class TestQuery(BaseModel):
    """A test question, its optional search-optimized variant, and ground truth."""

    question: str
    # Search-optimized form of the question; generated by the LLM when missing.
    question_search: Optional[str] = None
    ground_truth: list[GroundTruth] = []
    categories: list[str] = []
def load_test_queries() -> list[TestQuery]:
    """
    Loads the test queries from the test_queries.json file.
    If `question_search` is missing, it will use the tool-calling LLM to generate it.

    Raises:
        FileNotFoundError: if test_queries.json does not exist next to this file.
    """
    # open test queries file
    current_dir = Path(__file__).parent
    test_queries_path = current_dir / "test_queries.json"
    logger.info(f"Loading test queries from {test_queries_path}")
    if not test_queries_path.exists():
        raise FileNotFoundError(f"Test queries file not found at {test_queries_path}")
    with test_queries_path.open("r") as f:
        test_queries_raw: list[dict] = json.load(f)

    # setup llm for question_search generation
    with get_session_with_current_tenant() as db_session:
        persona = get_persona_by_id(DEFAULT_PERSONA_ID, None, db_session)
        llm, _ = get_llms_for_persona(persona)
        prompt_config = PromptConfig.from_model(persona.prompts[0])
        search_tool = SearchToolOverride()
        tool_call_supported = explicit_tool_calling_supported(
            llm.config.model_provider, llm.config.model_name
        )

        # validate keys and generate question_search if missing
        # (malformed entries are logged and skipped rather than failing the run)
        test_queries: list[TestQuery] = []
        for query_raw in test_queries_raw:
            try:
                test_query = TestQuery(**query_raw)
            except ValidationError as e:
                logger.error(f"Incorrectly formatted query: {e}")
                continue
            if test_query.question_search is None:
                test_query.question_search = _modify_one_query(
                    query=test_query.question,
                    llm=llm,
                    prompt_config=prompt_config,
                    tool=search_tool,
                    tool_call_supported=tool_call_supported,
                )
            test_queries.append(test_query)

    return test_queries
def export_test_queries(test_queries: list[TestQuery], export_path: Path) -> None:
    """Exports the test queries to a JSON file."""
    logger.info(f"Exporting test queries to {export_path}")
    serialized = [query.model_dump() for query in test_queries]
    with export_path.open("w") as out_file:
        json.dump(serialized, out_file, indent=4)
class SearchToolOverride(SearchTool):
    """A stub SearchTool that skips the parent's initialization entirely.

    This module only calls methods on it that rely on class-level state
    (e.g., building the tool definition), so no instance setup is needed.
    """

    def __init__(self) -> None:
        # do nothing, only class variables are required for the functions we call
        pass
# tracks whether the one-time question_search-generation warning was emitted
warned = False


def _modify_one_query(
    query: str,
    llm: LLM,
    prompt_config: PromptConfig,
    tool: SearchTool,
    tool_call_supported: bool,
    writer: StreamWriter = lambda _: None,
) -> str:
    """Derives the search query (question_search) for a question via the LLM.

    Falls back to the original question if the tool-calling LLM does not
    actually emit a tool call.
    """
    global warned
    if not warned:
        logger.warning(
            "Generating question_search. If you do not save the question_search, "
            "it will be generated again on the next run, potentially altering the search results."
        )
        warned = True

    builder = AnswerPromptBuilder(
        user_message=default_build_user_message(
            user_query=query,
            prompt_config=prompt_config,
            files=[],
            single_message_history=None,
        ),
        system_message=default_build_system_message(prompt_config, llm.config),
        message_history=[],
        llm_config=llm.config,
        raw_user_query=query,
        raw_user_uploaded_files=[],
        single_message_history=None,
    )

    if not tool_call_supported:
        # non-tool-calling path: ask the tool to derive its own arguments
        # from the conversation history
        tool_args = cast(
            dict[str, str],
            tool.get_args_for_non_tool_calling_llm(
                query=query,
                history=builder.get_message_history(),
                llm=llm,
                force_run=True,
            ),
        )
        return tool_args["query"]

    # tool-calling path: force the LLM to emit a search tool call and pull
    # the query argument out of it
    stream = llm.stream(
        prompt=builder.build(),
        tools=[tool.tool_definition()],
        tool_choice="required",
        structured_response_format=None,
    )
    tool_message = process_llm_stream(
        messages=stream,
        should_stream_answer=False,
        writer=writer,
    )
    if not tool_message.tool_calls:
        return query  # the LLM declined to call the tool
    return tool_message.tool_calls[0]["args"]["query"]

View File

@@ -1,94 +0,0 @@
from typing import Optional
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.context.search.models import InferenceChunk
from onyx.db.models import Document
from onyx.utils.logger import setup_logger
from tests.regression.search_quality.util_retrieve import group_by_documents
logger = setup_logger(__name__)
class Metrics(BaseModel):
    """Per-query search evaluation metrics; None means not computed."""

    # computed if ground truth is provided
    ground_truth_ratio_topk: Optional[float] = None
    ground_truth_avg_rank_delta: Optional[float] = None
    # computed if reranked results are provided
    soft_truth_ratio_topk: Optional[float] = None
    soft_truth_avg_rank_delta: Optional[float] = None


# names of all metric fields declared on Metrics, in declaration order
metric_names = list(Metrics.model_fields.keys())
def get_corresponding_document(
    doc_link: str, db_session: Session
) -> Optional[Document]:
    """Get the corresponding document from the database."""
    matches = db_session.query(Document).filter(Document.link == doc_link)
    match_count = matches.count()

    # no match: warn and skip; several matches: warn and take the first
    if match_count == 0:
        logger.warning(f"Could not find document with link {doc_link}, ignoring")
        return None
    if match_count > 1:
        logger.warning(f"Found multiple documents with link {doc_link}, using first")
    return matches.first()
def evaluate_one_query(
    search_chunks: list[InferenceChunk],
    rerank_chunks: list[InferenceChunk],
    true_documents: list[Document],
    topk: int,
) -> Metrics:
    """Computes metrics for the search results, relative to the ground truth and reranked results.

    Args:
        search_chunks: retrieved chunks, sorted by descending relevance
        rerank_chunks: reranked chunks, sorted by descending relevance
        true_documents: ground-truth documents for the query
        topk: number of top search results used for the ratio metrics

    Returns:
        Metrics with the computable fields set: ground-truth metrics require
        true_documents, soft-truth metrics require rerank_chunks.
    """
    metrics_dict: dict[str, float] = {}

    # map document ids to their 0-based rank in the search results
    search_documents = group_by_documents(search_chunks)
    search_ranks = {docid: rank for rank, docid in enumerate(search_documents)}
    search_ranks_topk = {
        docid: rank for rank, docid in enumerate(search_documents[:topk])
    }
    true_ranks = {doc.id: rank for rank, doc in enumerate(true_documents)}

    if true_documents:
        metrics_dict["ground_truth_ratio_topk"] = _compute_ratio(
            search_ranks_topk, true_ranks
        )
        metrics_dict["ground_truth_avg_rank_delta"] = _compute_avg_rank_delta(
            search_ranks, true_ranks
        )

    if rerank_chunks:
        # build soft truth out of ground truth + reranked results, up to topk
        # copy true_ranks: the original aliased it, silently mutating the
        # ground-truth ranks while extending the soft truth
        soft_ranks = dict(true_ranks)
        for docid in group_by_documents(rerank_chunks):
            if len(soft_ranks) >= topk:
                break
            if docid not in soft_ranks:
                soft_ranks[docid] = len(soft_ranks)
        metrics_dict["soft_truth_ratio_topk"] = _compute_ratio(
            search_ranks_topk, soft_ranks
        )
        metrics_dict["soft_truth_avg_rank_delta"] = _compute_avg_rank_delta(
            search_ranks, soft_ranks
        )

    return Metrics(**metrics_dict)
def _compute_ratio(search_ranks: dict[str, int], true_ranks: dict[str, int]) -> float:
return len(set(search_ranks) & set(true_ranks)) / len(true_ranks)
def _compute_avg_rank_delta(
search_ranks: dict[str, int], true_ranks: dict[str, int]
) -> float:
out = len(search_ranks)
return sum(
abs(search_ranks.get(docid, out) - rank) for docid, rank in true_ranks.items()
) / len(true_ranks)

View File

@@ -1,88 +0,0 @@
from sqlalchemy.orm import Session
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import RerankingDetails
from onyx.context.search.postprocessing.postprocessing import semantic_reranking
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.context.search.utils import remove_stop_words_and_punctuation
from onyx.document_index.interfaces import DocumentIndex
from onyx.utils.logger import setup_logger
from tests.regression.search_quality.util_config import SearchEvalConfig
logger = setup_logger(__name__)
def search_one_query(
    question_search: str,
    multilingual_expansion: list[str],
    document_index: DocumentIndex,
    db_session: Session,
    config: SearchEvalConfig,
) -> list[InferenceChunk]:
    """Uses the search pipeline to retrieve relevant chunks for the given query."""
    # the retrieval preprocessing is fairly stripped down so the query doesn't unexpectedly change
    query_embedding = get_query_embedding(question_search, db_session)

    query_terms = question_search.split()
    if multilingual_expansion:
        final_keywords = query_terms
    else:
        final_keywords = remove_stop_words_and_punctuation(query_terms)

    # keyword-style queries get a different hybrid weighting
    if query_analysis(question_search)[0]:
        hybrid_alpha = config.hybrid_alpha_keyword
    else:
        hybrid_alpha = config.hybrid_alpha

    acl = ["PUBLIC"]
    if config.user_email:
        acl.append(f"user_email:{config.user_email}")
    search_filters = IndexFilters(
        tags=[],
        user_file_ids=[],
        user_folder_ids=[],
        access_control_list=acl,
        tenant_id=None,
    )

    retrieved = document_index.hybrid_retrieval(
        query=question_search,
        query_embedding=query_embedding,
        final_keywords=final_keywords,
        filters=search_filters,
        hybrid_alpha=hybrid_alpha,
        time_decay_multiplier=config.doc_time_decay,
        num_to_retrieve=config.num_returned_hits,
        ranking_profile_type=config.rank_profile,
        offset=config.offset,
        title_content_ratio=config.title_content_ratio,
    )
    return [chunk.to_inference_chunk() for chunk in retrieved]
def rerank_one_query(
    question: str,
    retrieved_chunks: list[InferenceChunk],
    rerank_settings: RerankingDetails,
) -> list[InferenceChunk]:
    """Uses the reranker to rerank the retrieved chunks for the given query."""
    # rerank all retrieved chunks (NOTE: mutates rerank_settings.num_rerank)
    rerank_settings.num_rerank = len(retrieved_chunks)
    rerank_result = semantic_reranking(
        query_str=question,
        rerank_settings=rerank_settings,
        chunks=retrieved_chunks,
        rerank_metrics_callback=None,
    )
    # semantic_reranking returns a tuple; the reranked chunks come first
    return rerank_result[0]
def group_by_documents(chunks: list[InferenceChunk]) -> list[str]:
    """Groups a sorted list of chunks into a sorted list of document ids."""
    # dict.fromkeys deduplicates while preserving first-seen order
    return list(dict.fromkeys(chunk.document_id for chunk in chunks))

View File

@@ -0,0 +1,78 @@
from ragas import evaluate # type: ignore
from ragas import EvaluationDataset # type: ignore
from ragas import SingleTurnSample # type: ignore
from ragas.dataset_schema import EvaluationResult # type: ignore
from ragas.metrics import Faithfulness # type: ignore
from ragas.metrics import ResponseGroundedness # type: ignore
from ragas.metrics import ResponseRelevancy # type: ignore
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import SavedSearchDoc
from onyx.db.models import Document
from onyx.prompts.prompt_utils import build_doc_context_str
from onyx.utils.logger import setup_logger
from tests.regression.search_quality.models import GroundTruth
from tests.regression.search_quality.models import RetrievedDocument
logger = setup_logger(__name__)
def find_document(ground_truth: GroundTruth, db_session: Session) -> Document | None:
    """Find a document by its link."""
    # normalize the link based on the source system before matching
    link = ground_truth.doc_link
    if ground_truth.doc_source == DocumentSource.GOOGLE_DRIVE:
        for marker in ("/edit", "/view"):
            if marker in link:
                link = link.split(marker, 1)[0]
                break
    elif ground_truth.doc_source == DocumentSource.FIREFLIES:
        # drop any query string
        link = link.split("?", 1)[0]

    # prefix match so trailing link fragments don't prevent a hit
    matches = db_session.query(Document).filter(Document.link.ilike(f"{link}%")).all()
    if not matches:
        logger.warning("Could not find ground truth document: %s", link)
        return None
    if len(matches) > 1:
        logger.warning(
            "Found multiple ground truth documents: %s, using the first one: %s",
            link,
            matches[0].id,
        )
    return matches[0]
def search_docs_to_doc_contexts(docs: list[SavedSearchDoc]) -> list[RetrievedDocument]:
    """Builds a RetrievedDocument with a formatted context string per search doc."""
    retrieved_docs: list[RetrievedDocument] = []
    for ind, doc in enumerate(docs):
        context_str = build_doc_context_str(
            semantic_identifier=doc.semantic_identifier,
            source_type=doc.source_type,
            content=doc.blurb,  # getting the full content is painful
            metadata_dict=doc.metadata,
            updated_at=doc.updated_at,
            ind=ind,
            include_metadata=True,
        )
        retrieved_docs.append(
            RetrievedDocument(document_id=doc.document_id, content=context_str)
        )
    return retrieved_docs
def ragas_evaluate(question: str, answer: str, contexts: list[str]) -> EvaluationResult:
    """Scores the answer against the question and retrieved contexts via ragas."""
    # a single-sample dataset: one question/answer pair plus its contexts
    sample = SingleTurnSample(
        user_input=question,
        retrieved_contexts=contexts,
        response=answer,
    )
    eval_metrics = [
        ResponseRelevancy(),
        ResponseGroundedness(),
        Faithfulness(),
    ]
    return evaluate(EvaluationDataset([sample]), metrics=eval_metrics)