Compare commits

...

8 Commits

Author SHA1 Message Date
Wenxi Onyx
d973363c5c embedding changes for hackathon 2025-07-02 16:07:58 -07:00
joachim-danswer
013bed3157 fix 2025-06-30 15:19:42 -07:00
joachim-danswer
289f27c43a updates 2025-06-30 15:06:12 -07:00
joachim-danswer
736a9bd332 erase history 2025-06-30 09:01:23 -07:00
joachim-danswer
8bcad415bb nit 2025-06-30 08:16:43 -07:00
joachim-danswer
93e6e4a089 mypy nits 2025-06-30 07:49:55 -07:00
joachim-danswer
ed0062dce0 fix 2025-06-30 02:45:03 -07:00
joachim-danswer
6e8bf3120c hackathon v1 changes 2025-06-30 01:39:36 -07:00
9 changed files with 669 additions and 21 deletions

View File

@@ -51,6 +51,7 @@ def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
else:
continue
history_segments.append(f"{role}:\n {msg.content}\n\n")
return "\n".join(history_segments)

View File

@@ -1,7 +1,12 @@
import csv
import json
import os
from collections import defaultdict
from collections.abc import Callable
from pathlib import Path
from uuid import UUID
from langchain_core.messages import HumanMessage
from sqlalchemy.orm import Session
from onyx.agents.agent_search.models import GraphConfig
@@ -11,6 +16,9 @@ from onyx.agents.agent_search.models import GraphSearchConfig
from onyx.agents.agent_search.models import GraphTooling
from onyx.agents.agent_search.run_graph import run_agent_search_graph
from onyx.agents.agent_search.run_graph import run_basic_graph
from onyx.agents.agent_search.run_graph import (
run_basic_graph as run_hackathon_graph,
) # You can create your own graph
from onyx.agents.agent_search.run_graph import run_dc_graph
from onyx.agents.agent_search.run_graph import run_kb_graph
from onyx.chat.models import AgentAnswerPiece
@@ -22,9 +30,11 @@ from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import SubQuestionKey
from onyx.chat.models import ToolCallFinalResult
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
from onyx.configs.constants import BASIC_KEY
from onyx.context.search.models import RerankingDetails
@@ -44,6 +54,190 @@ logger = setup_logger()
BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
def _calc_score_for_pos(pos: int, max_acceptable_pos: int = 15) -> float:
"""
Calculate the score for a given position.
"""
if pos > max_acceptable_pos:
return 0
elif pos == 1:
return 1
elif pos == 2:
return 0.8
else:
return 4 / (pos + 5)
def _clean_doc_id_link(doc_link: str) -> str:
"""
Clean the google doc link.
"""
if "google.com" in doc_link:
if "/edit" in doc_link:
return "/edit".join(doc_link.split("/edit")[:-1])
elif "/view" in doc_link:
return "/view".join(doc_link.split("/view")[:-1])
else:
return doc_link
if "app.fireflies.ai" in doc_link:
return "?".join(doc_link.split("?")[:-1])
return doc_link
def _get_doc_score(doc_id: str, doc_results: list[str]) -> float:
    """
    Score how well ``doc_id`` ranks within ``doc_results``.

    Links are normalized with ``_clean_doc_id_link`` before comparison.

    Returns:
        The positional score of the first (best) match, or 0.0 if the
        document does not appear in the results.
    """
    # Hoisted out of the loop: the cleaned target link is loop-invariant.
    clean_doc_id = _clean_doc_id_link(doc_id)
    for pos, candidate in enumerate(doc_results, start=1):
        if _clean_doc_id_link(candidate) == clean_doc_id:
            # Fix: stop at the first match. The original kept scanning and
            # used the LAST match, under-scoring duplicated links.
            return _calc_score_for_pos(pos)
    return 0.0
def _append_empty_line(csv_path: str = HACKATHON_OUTPUT_CSV_PATH) -> None:
    """
    Append an empty line to the CSV file.

    Implemented as an answer row with blank query/answer so the separator row
    still matches the file's five-column layout.
    """
    _append_answer_to_csv("", "", csv_path)
def _append_ground_truth_to_csv(
    query: str,
    ground_truth_docs: list[str],
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the ground-truth documents for a query to the CSV file.

    Each ground-truth document gets its own row with a sentinel position of
    -1, distinguishing it from real search-result rows.

    Args:
        query: The query the ground-truth documents belong to.
        ground_truth_docs: Links of the expected (ground-truth) documents.
        csv_path: Path to the CSV file to append to.
    """
    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist; mkdir(exist_ok=True) is race-free,
    # so no separate exists() check is needed.
    csv_dir = os.path.dirname(csv_path)
    if csv_dir:
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])
        for doc_id in ground_truth_docs:
            writer.writerow([query, "-1", _clean_doc_id_link(doc_id), "", ""])

    # Fix: previous message claimed a "score" was appended.
    logger.debug("Appended ground truth docs to csv file")
def _append_score_to_csv(
    query: str,
    score: float,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append a single score row for ``query`` to the CSV file.

    Only the query and score columns are populated; position, document_id,
    and answer stay blank.
    """
    header_needed = not os.path.isfile(csv_path)

    # Make sure the target directory exists before opening for append.
    parent_dir = os.path.dirname(csv_path)
    if parent_dir and not os.path.exists(parent_dir):
        Path(parent_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as out_file:
        csv_writer = csv.writer(out_file)
        if header_needed:
            csv_writer.writerow(
                ["query", "position", "document_id", "answer", "score"]
            )
        csv_writer.writerow([query, "", "", "", score])

    logger.debug("Appended score to csv file")
def _append_search_results_to_csv(
    query: str,
    doc_results: list[str],
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Write one row per retrieved document for ``query`` to the CSV file.

    Rows carry the 1-based retrieval position and the normalized document
    link; answer and score columns stay blank.
    """
    header_needed = not os.path.isfile(csv_path)

    # Make sure the target directory exists before opening for append.
    parent_dir = os.path.dirname(csv_path)
    if parent_dir and not os.path.exists(parent_dir):
        Path(parent_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as out_file:
        csv_writer = csv.writer(out_file)
        if header_needed:
            csv_writer.writerow(
                ["query", "position", "document_id", "answer", "score"]
            )
        csv_writer.writerows(
            [query, position, _clean_doc_id_link(doc_link), "", ""]
            for position, doc_link in enumerate(doc_results, start=1)
        )

    logger.debug("Appended search results to csv file")
def _append_answer_to_csv(
    query: str,
    answer: str,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the generated answer for a query to the CSV file.

    Args:
        query: The query the answer belongs to.
        answer: The LLM-generated answer text (may be empty).
        csv_path: Path to the CSV file to append to.
    """
    # Fix: docstring previously described a nonexistent ``ranking_stats`` param.
    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist; mkdir(exist_ok=True) is race-free,
    # so no separate exists() check is needed.
    csv_dir = os.path.dirname(csv_path)
    if csv_dir:
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])
        # Answer rows leave position/document_id/score blank.
        writer.writerow([query, "", "", answer, ""])

    logger.debug("Appended answer to csv file")
class Answer:
def __init__(
self,
@@ -134,6 +328,9 @@ class Answer:
@property
def processed_streamed_output(self) -> AnswerStream:
_HACKATHON_TEST_EXECUTION = False
if self._processed_stream is not None:
yield from self._processed_stream
return
@@ -154,22 +351,124 @@ class Answer:
)
):
run_langgraph = run_dc_graph
elif (
self.graph_config.inputs.persona
and self.graph_config.inputs.persona.description.startswith(
"Hackathon Test"
)
):
_HACKATHON_TEST_EXECUTION = True
run_langgraph = run_hackathon_graph
else:
run_langgraph = run_basic_graph
stream = run_langgraph(
self.graph_config,
)
if _HACKATHON_TEST_EXECUTION:
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
# Enable search-only mode for hackathon test execution
self.graph_config.behavior.skip_gen_ai_answer_generation = True
# Disable reranking for faster processing
self.graph_config.behavior.allow_agent_reranking = False
self.graph_config.behavior.allow_refinement = False
# Disable query rewording for more predictable results
input_data = str(self.graph_config.inputs.prompt_builder.raw_user_query)
if input_data.startswith("["):
input_type = "json"
input_list = json.loads(input_data)
else:
input_type = "list"
input_list = input_data.split(";")
num_examples_with_ground_truth = 0
total_score = 0.0
for question_num, question_data in enumerate(input_list):
ground_truth_docs = None
if input_type == "json":
question = question_data["question"]
ground_truth = question_data.get("ground_truth")
if ground_truth:
ground_truth_docs = [x.get("doc_link") for x in ground_truth]
logger.info(f"Question {question_num}: {question}")
_append_ground_truth_to_csv(question, ground_truth_docs)
else:
continue
else:
question = question_data
self.graph_config.inputs.prompt_builder.raw_user_query = question
self.graph_config.inputs.prompt_builder.user_message_and_token_cnt = (
HumanMessage(
content=question, additional_kwargs={}, response_metadata={}
),
2,
)
self.graph_config.tooling.force_use_tool.force_use = True
stream = run_langgraph(
self.graph_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
llm_answer_segments: list[str] = []
doc_results: list[str] | None = None
for answer_piece in processed_stream:
if isinstance(answer_piece, OnyxAnswerPiece):
llm_answer_segments.append(answer_piece.answer_piece or "")
elif isinstance(answer_piece, ToolCallFinalResult):
doc_results = [x.get("link") for x in answer_piece.tool_result]
if doc_results:
_append_search_results_to_csv(question, doc_results)
_append_answer_to_csv(question, "".join(llm_answer_segments))
if ground_truth_docs and doc_results:
num_examples_with_ground_truth += 1
doc_score = 0.0
for doc_id in ground_truth_docs:
doc_score += _get_doc_score(doc_id, doc_results)
_append_score_to_csv(question, doc_score)
total_score += doc_score
self._processed_stream = processed_stream
if num_examples_with_ground_truth > 0:
comprehensive_score = total_score / num_examples_with_ground_truth
else:
comprehensive_score = 0
_append_empty_line()
_append_score_to_csv(question, comprehensive_score)
else:
stream = run_langgraph(
self.graph_config,
)
processed_stream = []
for packet in stream:
if self.is_cancelled():
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
yield packet
break
processed_stream.append(packet)
yield packet
break
processed_stream.append(packet)
yield packet
self._processed_stream = processed_stream
self._processed_stream = processed_stream
@property
def llm_answer(self) -> str:

View File

@@ -1012,6 +1012,7 @@ def stream_chat_message_objects(
tools=tools,
db_session=db_session,
use_agentic_search=new_msg_req.use_agentic_search,
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
)
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(

View File

@@ -187,8 +187,12 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
if self.new_messages_and_token_cnts:
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
if (
self.new_messages_and_token_cnts
and isinstance(self.user_message_and_token_cnt[0].content, str)
and self.user_message_and_token_cnt[0].content.startswith("Refer")
):
final_messages_with_tokens.extend(self.new_messages_and_token_cnts[-2:])
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -787,3 +787,7 @@ S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
# Forcing Vespa Language
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")
# Destination CSV for hackathon evaluation output (search results, answers,
# and scores appended by the hackathon test execution path).
HACKATHON_OUTPUT_CSV_PATH = os.environ.get(
    "HACKATHON_OUTPUT_CSV_PATH", "/tmp/hackathon_output.csv"
)

View File

@@ -1,10 +1,13 @@
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import cast
from typing import List
import openai
import requests
from pydantic import BaseModel
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
@@ -48,6 +51,133 @@ _FIREFLIES_API_QUERY = """
ONE_MINUTE = 60
class DocumentClassificationResult(BaseModel):
    """Structured-output schema for LLM-based document classification."""

    # Up to 5 short category labels describing the document (per the prompt).
    categories: list[str]
    # Up to 5 proper-noun entities mentioned in the document (per the prompt).
    entities: list[str]
def _extract_categories_and_entities(
    sections: list[TextSection | ImageSection],
) -> dict[str, list[str]]:
    """Extract categories and entities from document sections with retry logic.

    Joins the text of all non-empty TextSections, truncates it to roughly 12k
    tokens, and asks an OpenAI model for up to 5 categories and 5 entities.

    Returns:
        ``{"categories": [...], "entities": [...]}`` — both lists empty when
        the API key is missing, there is no text content, or all retries fail.
    """
    import random
    import time

    prompt = """
Analyze this document, classify it with categories, and extract important entities.
CATEGORIES:
Create up to 5 simple categories that best capture what this document is about. Consider categories within:
- Document type (e.g., Manual, Report, Email, Transcript, etc.)
- Content domain (e.g., Technical, Financial, HR, Marketing, etc.)
- Purpose (e.g., Training, Reference, Announcement, Analysis, etc.)
- Industry/Topic area (e.g., Software Development, Sales, Legal, etc.)
Be creative and specific. Use clear, descriptive terms that someone searching for this document might use.
Categories should be up to 2 words each.
ENTITIES:
Extract up to 5 important proper nouns, such as:
- Company names (e.g., Microsoft, Google, Acme Corp)
- Product names (e.g., Office 365, Salesforce, iPhone)
- People's names (e.g. John, Jane, Ahmed, Wenjie, etc.)
- Department names (e.g., Engineering, Marketing, HR)
- Project names (e.g., Project Alpha, Migration 2024)
- Technology names (e.g., PostgreSQL, React, AWS)
- Location names (e.g., New York Office, Building A)
"""

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        logger.warning("OPENAI_API_KEY not set, skipping metadata extraction")
        return {"categories": [], "entities": []}

    # Combine all section text. Hoisted out of the retry loop along with the
    # api-key check and truncation: all are loop-invariant, but the original
    # recomputed them on every attempt.
    document_text = "\n\n".join(
        section.text
        for section in sections
        if isinstance(section, TextSection) and section.text.strip()
    )

    # Skip if no text content
    if not document_text.strip():
        logger.debug("No text content found, skipping metadata extraction")
        return {"categories": [], "entities": []}

    # Truncate very long documents to avoid token limits
    max_chars = 50000  # Roughly 12k tokens
    if len(document_text) > max_chars:
        document_text = document_text[:max_chars] + "..."
        logger.debug(f"Truncated document text to {max_chars} characters")

    # Retry configuration
    max_retries = 3
    base_delay = 1.0  # seconds
    backoff_factor = 2.0

    for attempt in range(max_retries + 1):
        try:
            client = openai.OpenAI(api_key=api_key)
            response = client.responses.parse(
                model="o3",
                input=[
                    {
                        "role": "system",
                        "content": "Extract categories and entities from the document.",
                    },
                    {
                        "role": "user",
                        "content": prompt + "\n\nDOCUMENT: " + document_text,
                    },
                ],
                text_format=DocumentClassificationResult,
            )
            classification_result = response.output_parsed
            if classification_result is None:
                # Fix: responses.parse() can yield no parsed output (e.g. on a
                # refusal); raise so the retry logic handles it instead of
                # crashing on attribute access of None.
                raise ValueError("model returned no parsed classification")
            result = {
                "categories": classification_result.categories,
                "entities": classification_result.entities,
            }
            logger.debug(f"Successfully extracted metadata: {result}")
            return result
        except Exception as e:
            if attempt == max_retries:
                # Out of retries: log and fall back to empty metadata.
                logger.error(
                    f"Failed to extract categories and entities after {max_retries + 1} attempts: {e}"
                )
                return {"categories": [], "entities": []}
            logger.warning(
                f"Attempt {attempt + 1} failed to extract metadata: {e}. Retrying..."
            )
            # Exponential backoff plus 0.1-0.3s absolute jitter (the original
            # comment mislabeled this as "10-30%" jitter).
            total_delay = base_delay * (backoff_factor**attempt) + random.uniform(
                0.1, 0.3
            )
            logger.debug(
                f"Waiting {total_delay:.2f} seconds before retry {attempt + 2}"
            )
            time.sleep(total_delay)

    # Should never reach here, but just in case
    return {"categories": [], "entities": []}
def _create_doc_from_transcript(transcript: dict) -> Document | None:
sections: List[TextSection] = []
current_speaker_name = None
@@ -96,12 +226,19 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
if participant != meeting_organizer_email and participant:
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
# Extract categories and entities from transcript and store in metadata
categories_and_entities = _extract_categories_and_entities(sections)
metadata = {
"categories": categories_and_entities.get("categories", []),
"entities": categories_and_entities.get("entities", []),
}
return Document(
id=fireflies_id,
sections=cast(list[TextSection | ImageSection], sections),
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},
metadata=metadata,
doc_updated_at=meeting_date,
primary_owners=organizer_email_user_info,
secondary_owners=meeting_participants_email_list,

View File

@@ -1,9 +1,11 @@
import io
import os
from collections.abc import Callable
from datetime import datetime
from typing import Any
from typing import cast
import openai
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import MediaIoBaseDownload # type: ignore
from pydantic import BaseModel
@@ -45,6 +47,12 @@ from onyx.utils.variable_functionality import noop_fallback
logger = setup_logger()
class DocumentClassificationResult(BaseModel):
    """Structured-output schema for LLM-based document classification."""

    # Up to 5 short category labels describing the document (per the prompt).
    categories: list[str]
    # Up to 5 proper-noun entities mentioned in the document (per the prompt).
    entities: list[str]
# This is not a standard valid unicode char, it is used by the docs advanced API to
# represent smart chips (elements like dates and doc links).
SMART_CHIP_CHAR = "\ue907"
@@ -406,6 +414,128 @@ def convert_drive_item_to_document(
return first_error
def _extract_categories_and_entities(
    sections: list[TextSection | ImageSection],
) -> dict[str, list[str]]:
    """Extract categories and entities from document sections with retry logic.

    Joins the text of all non-empty TextSections, truncates it to roughly 12k
    tokens, and asks an OpenAI model for up to 5 categories and 5 entities.

    NOTE(review): this function is duplicated verbatim in the Fireflies
    connector — consider moving it to a shared module.

    Returns:
        ``{"categories": [...], "entities": [...]}`` — both lists empty when
        the API key is missing, there is no text content, or all retries fail.
    """
    import random
    import time

    prompt = """
Analyze this document, classify it with categories, and extract important entities.
CATEGORIES:
Create up to 5 simple categories that best capture what this document is about. Consider categories within:
- Document type (e.g., Manual, Report, Email, Transcript, etc.)
- Content domain (e.g., Technical, Financial, HR, Marketing, etc.)
- Purpose (e.g., Training, Reference, Announcement, Analysis, etc.)
- Industry/Topic area (e.g., Software Development, Sales, Legal, etc.)
Be creative and specific. Use clear, descriptive terms that someone searching for this document might use.
Categories should be up to 2 words each.
ENTITIES:
Extract up to 5 important proper nouns, such as:
- Company names (e.g., Microsoft, Google, Acme Corp)
- Product names (e.g., Office 365, Salesforce, iPhone)
- People's names (e.g. John, Jane, Ahmed, Wenjie, etc.)
- Department names (e.g., Engineering, Marketing, HR)
- Project names (e.g., Project Alpha, Migration 2024)
- Technology names (e.g., PostgreSQL, React, AWS)
- Location names (e.g., New York Office, Building A)
"""

    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        logger.warning("OPENAI_API_KEY not set, skipping metadata extraction")
        return {"categories": [], "entities": []}

    # Combine all section text. Hoisted out of the retry loop along with the
    # api-key check and truncation: all are loop-invariant, but the original
    # recomputed them on every attempt.
    document_text = "\n\n".join(
        section.text
        for section in sections
        if isinstance(section, TextSection) and section.text.strip()
    )

    # Skip if no text content
    if not document_text.strip():
        logger.debug("No text content found, skipping metadata extraction")
        return {"categories": [], "entities": []}

    # Truncate very long documents to avoid token limits
    max_chars = 50000  # Roughly 12k tokens
    if len(document_text) > max_chars:
        document_text = document_text[:max_chars] + "..."
        logger.debug(f"Truncated document text to {max_chars} characters")

    # Retry configuration
    max_retries = 3
    base_delay = 1.0  # seconds
    backoff_factor = 2.0

    for attempt in range(max_retries + 1):
        try:
            client = openai.OpenAI(api_key=api_key)
            response = client.responses.parse(
                model="o3",
                input=[
                    {
                        "role": "system",
                        "content": "Extract categories and entities from the document.",
                    },
                    {
                        "role": "user",
                        "content": prompt + "\n\nDOCUMENT: " + document_text,
                    },
                ],
                text_format=DocumentClassificationResult,
            )
            classification_result = response.output_parsed
            if classification_result is None:
                # Fix: responses.parse() can yield no parsed output (e.g. on a
                # refusal); raise so the retry logic handles it instead of
                # crashing on attribute access of None.
                raise ValueError("model returned no parsed classification")
            result = {
                "categories": classification_result.categories,
                "entities": classification_result.entities,
            }
            logger.debug(f"Successfully extracted metadata: {result}")
            return result
        except Exception as e:
            if attempt == max_retries:
                # Out of retries: log and fall back to empty metadata.
                logger.error(
                    f"Failed to extract categories and entities after {max_retries + 1} attempts: {e}"
                )
                return {"categories": [], "entities": []}
            logger.warning(
                f"Attempt {attempt + 1} failed to extract metadata: {e}. Retrying..."
            )
            # Exponential backoff plus 0.1-0.3s absolute jitter (the original
            # comment mislabeled this as "10-30%" jitter).
            total_delay = base_delay * (backoff_factor**attempt) + random.uniform(
                0.1, 0.3
            )
            logger.debug(
                f"Waiting {total_delay:.2f} seconds before retry {attempt + 2}"
            )
            time.sleep(total_delay)

    # Should never reach here, but just in case
    return {"categories": [], "entities": []}
def _convert_drive_item_to_document(
creds: Any,
allow_images: bool,
@@ -499,17 +629,23 @@ def _convert_drive_item_to_document(
else None
)
# Extract categories and entities from drive item and store in metadata
categories_and_entities = _extract_categories_and_entities(sections)
metadata = {
"owner_names": ", ".join(
owner.get("displayName", "") for owner in file.get("owners", [])
),
"categories": categories_and_entities.get("categories", []),
"entities": categories_and_entities.get("entities", []),
}
# Create the document
return Document(
id=doc_id,
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file.get("name", ""),
metadata={
"owner_names": ", ".join(
owner.get("displayName", "") for owner in file.get("owners", [])
),
},
metadata=metadata,
doc_updated_at=datetime.fromisoformat(
file.get("modifiedTime", "").replace("Z", "+00:00")
),

View File

@@ -309,10 +309,30 @@ def docx_to_text_and_images(
paragraphs = []
embedded_images: list[tuple[bytes, str]] = []
# Debug: Check file properties before processing
file.seek(0)
first_bytes = file.read(100) # Read first 100 bytes
file.seek(0) # Reset position
logger.debug(f"Processing file: {file_name}")
logger.debug(f"File size: {getattr(file, 'size', 'unknown')}")
logger.debug(f"First 100 bytes: {first_bytes}")
logger.debug(
f"File type check - starts with PK (ZIP): {first_bytes.startswith(b'PK')}"
)
try:
doc = docx.Document(file)
except BadZipFile as e:
logger.warning(f"Failed to extract text from {file_name or 'docx file'}: {e}")
logger.error(f"BadZipFile error for {file_name}: {e}")
logger.error(f"File first bytes: {first_bytes}")
logger.error(f"Is this actually a ZIP file? {first_bytes.startswith(b'PK')}")
return "", []
except Exception as e:
logger.error(
f"Unexpected error processing DOCX file {file_name}: {type(e).__name__}: {e}"
)
logger.error(f"File first bytes: {first_bytes}")
return "", []
# Grab text from paragraphs
@@ -523,7 +543,7 @@ def extract_text_and_images(
# docx example for embedded images
if extension == ".docx":
file.seek(0)
text_content, images = docx_to_text_and_images(file)
text_content, images = docx_to_text_and_images(file, file_name=file_name)
return ExtractionResult(
text_content=text_content, embedded_images=images, metadata={}
)

View File

@@ -1,7 +1,10 @@
import copy
import csv
import json
import os
from collections.abc import Callable
from collections.abc import Generator
from pathlib import Path
from typing import Any
from typing import cast
from typing import TypeVar
@@ -19,6 +22,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from onyx.chat.prune_and_merge import prune_and_merge_sections
from onyx.chat.prune_and_merge import prune_sections
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
@@ -62,6 +66,39 @@ SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"
def _append_ranking_stats_to_csv(
    llm_doc_results: list[tuple[int, str]],
    query: str,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append ranking statistics to a CSV file.

    Args:
        llm_doc_results: (position, document_id) tuples for the query.
        query: The query the ranking stats belong to.
        csv_path: Path to the CSV file to append to.
    """
    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist; mkdir(exist_ok=True) is race-free,
    # so no separate exists() check is needed.
    csv_dir = os.path.dirname(csv_path)
    if csv_dir:
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)
        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])
        # Write the ranking stats. Fix: rows previously had only 4 fields,
        # misaligned with the 5-column header (score column was missing).
        for pos, doc in llm_doc_results:
            writer.writerow([query, pos, doc, "", ""])

    logger.debug(f"Appended {len(llm_doc_results)} ranking stats to {csv_path}")
class SearchResponseSummary(SearchQueryInfo):
top_sections: list[InferenceSection]
rephrased_query: str | None = None
@@ -303,6 +340,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
kg_sources = None
kg_chunk_id_zero_only = False
if override_kwargs:
precomputed_is_keyword = override_kwargs.precomputed_is_keyword
precomputed_keywords = override_kwargs.precomputed_keywords
precomputed_query_embedding = override_kwargs.precomputed_query_embedding
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
@@ -499,6 +539,12 @@ def yield_search_responses(
)
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
# Append ranking statistics to a CSV file
llm_doc_results = []
for pos, doc in enumerate(llm_docs):
llm_doc_results.append((pos, doc.document_id))
# _append_ranking_stats_to_csv(llm_doc_results, query)
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)