mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-26 01:52:45 +00:00
Compare commits
8 Commits
v3.0.4
...
hackathon/
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d973363c5c | ||
|
|
013bed3157 | ||
|
|
289f27c43a | ||
|
|
736a9bd332 | ||
|
|
8bcad415bb | ||
|
|
93e6e4a089 | ||
|
|
ed0062dce0 | ||
|
|
6e8bf3120c |
@@ -51,6 +51,7 @@ def _create_history_str(prompt_builder: AnswerPromptBuilder) -> str:
|
||||
else:
|
||||
continue
|
||||
history_segments.append(f"{role}:\n {msg.content}\n\n")
|
||||
|
||||
return "\n".join(history_segments)
|
||||
|
||||
|
||||
|
||||
@@ -1,7 +1,12 @@
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
from collections import defaultdict
|
||||
from collections.abc import Callable
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
@@ -11,6 +16,9 @@ from onyx.agents.agent_search.models import GraphSearchConfig
|
||||
from onyx.agents.agent_search.models import GraphTooling
|
||||
from onyx.agents.agent_search.run_graph import run_agent_search_graph
|
||||
from onyx.agents.agent_search.run_graph import run_basic_graph
|
||||
from onyx.agents.agent_search.run_graph import (
|
||||
run_basic_graph as run_hackathon_graph,
|
||||
) # You can create your own graph
|
||||
from onyx.agents.agent_search.run_graph import run_dc_graph
|
||||
from onyx.agents.agent_search.run_graph import run_kb_graph
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
@@ -22,9 +30,11 @@ from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import SubQuestionKey
|
||||
from onyx.chat.models import ToolCallFinalResult
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import AGENT_ALLOW_REFINEMENT
|
||||
from onyx.configs.agent_configs import INITIAL_SEARCH_DECOMPOSITION_ENABLED
|
||||
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
|
||||
from onyx.configs.chat_configs import USE_DIV_CON_AGENT
|
||||
from onyx.configs.constants import BASIC_KEY
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
@@ -44,6 +54,190 @@ logger = setup_logger()
|
||||
BASIC_SQ_KEY = SubQuestionKey(level=BASIC_KEY[0], question_num=BASIC_KEY[1])
|
||||
|
||||
|
||||
def _calc_score_for_pos(pos: int, max_acceptable_pos: int = 15) -> float:
|
||||
"""
|
||||
Calculate the score for a given position.
|
||||
"""
|
||||
if pos > max_acceptable_pos:
|
||||
return 0
|
||||
|
||||
elif pos == 1:
|
||||
return 1
|
||||
elif pos == 2:
|
||||
return 0.8
|
||||
else:
|
||||
return 4 / (pos + 5)
|
||||
|
||||
|
||||
def _clean_doc_id_link(doc_link: str) -> str:
|
||||
"""
|
||||
Clean the google doc link.
|
||||
"""
|
||||
if "google.com" in doc_link:
|
||||
if "/edit" in doc_link:
|
||||
return "/edit".join(doc_link.split("/edit")[:-1])
|
||||
elif "/view" in doc_link:
|
||||
return "/view".join(doc_link.split("/view")[:-1])
|
||||
else:
|
||||
return doc_link
|
||||
|
||||
if "app.fireflies.ai" in doc_link:
|
||||
return "?".join(doc_link.split("?")[:-1])
|
||||
return doc_link
|
||||
|
||||
|
||||
def _get_doc_score(doc_id: str, doc_results: list[str]) -> float:
    """Score a ground-truth document against ranked retrieval results.

    Args:
        doc_id: Link of the ground-truth document.
        doc_results: Retrieved document links in rank order (best first).

    Returns:
        The position-based score of the first (best-ranked) match, or 0.0
        when the document does not appear in the results.
    """
    # Normalizing the target link is loop-invariant, so do it once up front.
    clear_doc_id = _clean_doc_id_link(doc_id)

    for pos, comp_doc in enumerate(doc_results, start=1):
        if clear_doc_id == _clean_doc_id_link(comp_doc):
            # Return on the first hit: with a ranked list the earliest match
            # is the best one. The old code kept overwriting the match
            # position, so a duplicated doc was scored at its worst rank.
            return _calc_score_for_pos(pos)

    return 0.0
|
||||
|
||||
|
||||
def _append_empty_line(csv_path: str = HACKATHON_OUTPUT_CSV_PATH) -> None:
    """
    Append an empty line to the CSV file.

    Implemented as an answer row whose query and answer are both empty;
    serves as a visual separator between result groups in the output CSV.
    """
    _append_answer_to_csv("", "", csv_path)
|
||||
|
||||
|
||||
def _append_ground_truth_to_csv(
    query: str,
    ground_truth_docs: list[str],
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """Append the ground-truth documents for a query to the CSV file.

    Each ground-truth document becomes its own row with position "-1" so it
    can be told apart from retrieved-result rows (which use positions >= 1).

    Args:
        query: The question the ground-truth documents belong to.
        ground_truth_docs: Links of the documents that should be retrieved.
        csv_path: Path of the CSV file to append to (created if missing).
    """
    file_exists = os.path.isfile(csv_path)

    # Ensure the parent directory exists; mkdir(exist_ok=True) is idempotent,
    # so no separate existence check is needed.
    csv_dir = os.path.dirname(csv_path)
    if csv_dir:
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)

        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])

        # One row per ground-truth document; position -1 marks ground truth.
        for doc_id in ground_truth_docs:
            writer.writerow([query, "-1", _clean_doc_id_link(doc_id), "", ""])

    # Old message said "Appended score", which was a copy-paste error.
    logger.debug("Appended ground truth to csv file")
|
||||
|
||||
|
||||
def _append_score_to_csv(
    query: str,
    score: float,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the score to the CSV file.
    """

    # Decide up front whether the header row is still needed.
    header_needed = not os.path.isfile(csv_path)

    # Make sure the target directory exists before opening the file.
    parent_dir = os.path.dirname(csv_path)
    if parent_dir and not os.path.exists(parent_dir):
        Path(parent_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as out_file:
        csv_writer = csv.writer(out_file)

        if header_needed:
            # Brand-new file: start with the column header row.
            csv_writer.writerow(["query", "position", "document_id", "answer", "score"])

        # Score rows leave position/document_id/answer blank.
        csv_writer.writerow([query, "", "", "", score])

    logger.debug("Appended score to csv file")
|
||||
|
||||
|
||||
def _append_search_results_to_csv(
    query: str,
    doc_results: list[str],
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append the search results to the CSV file.
    """

    # Header is only written when the file does not exist yet.
    header_needed = not os.path.isfile(csv_path)

    # Create the parent directory on first use.
    parent_dir = os.path.dirname(csv_path)
    if parent_dir and not os.path.exists(parent_dir):
        Path(parent_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as out_file:
        csv_writer = csv.writer(out_file)

        if header_needed:
            csv_writer.writerow(["query", "position", "document_id", "answer", "score"])

        # One row per retrieved document, with its 1-based rank and a
        # normalized link; answer/score columns stay blank.
        for rank, doc_link in enumerate(doc_results, start=1):
            csv_writer.writerow([query, rank, _clean_doc_id_link(doc_link), "", ""])

    logger.debug("Appended search results to csv file")
|
||||
|
||||
|
||||
def _append_answer_to_csv(
    query: str,
    answer: str,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """Append the generated answer for a query to the CSV file.

    Args:
        query: The question the answer belongs to (may be "" for spacer rows).
        answer: The generated answer text (may be "" for spacer rows).
        csv_path: Path to the CSV file to append to (created if missing).
    """
    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist
    csv_dir = os.path.dirname(csv_path)
    if csv_dir and not os.path.exists(csv_dir):
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)

        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])

        # Answer rows leave position/document_id/score blank.
        writer.writerow([query, "", "", answer, ""])

    logger.debug("Appended answer to csv file")
|
||||
|
||||
|
||||
class Answer:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -134,6 +328,9 @@ class Answer:
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
|
||||
_HACKATHON_TEST_EXECUTION = False
|
||||
|
||||
if self._processed_stream is not None:
|
||||
yield from self._processed_stream
|
||||
return
|
||||
@@ -154,22 +351,124 @@ class Answer:
|
||||
)
|
||||
):
|
||||
run_langgraph = run_dc_graph
|
||||
|
||||
elif (
|
||||
self.graph_config.inputs.persona
|
||||
and self.graph_config.inputs.persona.description.startswith(
|
||||
"Hackathon Test"
|
||||
)
|
||||
):
|
||||
_HACKATHON_TEST_EXECUTION = True
|
||||
run_langgraph = run_hackathon_graph
|
||||
|
||||
else:
|
||||
run_langgraph = run_basic_graph
|
||||
|
||||
stream = run_langgraph(
|
||||
self.graph_config,
|
||||
)
|
||||
if _HACKATHON_TEST_EXECUTION:
|
||||
|
||||
processed_stream = []
|
||||
for packet in stream:
|
||||
if self.is_cancelled():
|
||||
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
# Enable search-only mode for hackathon test execution
|
||||
self.graph_config.behavior.skip_gen_ai_answer_generation = True
|
||||
# Disable reranking for faster processing
|
||||
self.graph_config.behavior.allow_agent_reranking = False
|
||||
self.graph_config.behavior.allow_refinement = False
|
||||
|
||||
# Disable query rewording for more predictable results
|
||||
input_data = str(self.graph_config.inputs.prompt_builder.raw_user_query)
|
||||
|
||||
if input_data.startswith("["):
|
||||
input_type = "json"
|
||||
input_list = json.loads(input_data)
|
||||
else:
|
||||
input_type = "list"
|
||||
input_list = input_data.split(";")
|
||||
|
||||
num_examples_with_ground_truth = 0
|
||||
total_score = 0.0
|
||||
|
||||
for question_num, question_data in enumerate(input_list):
|
||||
|
||||
ground_truth_docs = None
|
||||
if input_type == "json":
|
||||
question = question_data["question"]
|
||||
ground_truth = question_data.get("ground_truth")
|
||||
if ground_truth:
|
||||
ground_truth_docs = [x.get("doc_link") for x in ground_truth]
|
||||
logger.info(f"Question {question_num}: {question}")
|
||||
_append_ground_truth_to_csv(question, ground_truth_docs)
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
question = question_data
|
||||
|
||||
self.graph_config.inputs.prompt_builder.raw_user_query = question
|
||||
self.graph_config.inputs.prompt_builder.user_message_and_token_cnt = (
|
||||
HumanMessage(
|
||||
content=question, additional_kwargs={}, response_metadata={}
|
||||
),
|
||||
2,
|
||||
)
|
||||
self.graph_config.tooling.force_use_tool.force_use = True
|
||||
|
||||
stream = run_langgraph(
|
||||
self.graph_config,
|
||||
)
|
||||
processed_stream = []
|
||||
for packet in stream:
|
||||
if self.is_cancelled():
|
||||
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield packet
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
yield packet
|
||||
|
||||
llm_answer_segments: list[str] = []
|
||||
doc_results: list[str] | None = None
|
||||
for answer_piece in processed_stream:
|
||||
if isinstance(answer_piece, OnyxAnswerPiece):
|
||||
llm_answer_segments.append(answer_piece.answer_piece or "")
|
||||
elif isinstance(answer_piece, ToolCallFinalResult):
|
||||
doc_results = [x.get("link") for x in answer_piece.tool_result]
|
||||
|
||||
if doc_results:
|
||||
_append_search_results_to_csv(question, doc_results)
|
||||
|
||||
_append_answer_to_csv(question, "".join(llm_answer_segments))
|
||||
|
||||
if ground_truth_docs and doc_results:
|
||||
num_examples_with_ground_truth += 1
|
||||
doc_score = 0.0
|
||||
for doc_id in ground_truth_docs:
|
||||
doc_score += _get_doc_score(doc_id, doc_results)
|
||||
|
||||
_append_score_to_csv(question, doc_score)
|
||||
total_score += doc_score
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
|
||||
if num_examples_with_ground_truth > 0:
|
||||
comprehensive_score = total_score / num_examples_with_ground_truth
|
||||
else:
|
||||
comprehensive_score = 0
|
||||
|
||||
_append_empty_line()
|
||||
|
||||
_append_score_to_csv(question, comprehensive_score)
|
||||
|
||||
else:
|
||||
|
||||
stream = run_langgraph(
|
||||
self.graph_config,
|
||||
)
|
||||
|
||||
processed_stream = []
|
||||
for packet in stream:
|
||||
if self.is_cancelled():
|
||||
packet = StreamStopInfo(stop_reason=StreamStopReason.CANCELLED)
|
||||
yield packet
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
yield packet
|
||||
break
|
||||
processed_stream.append(packet)
|
||||
yield packet
|
||||
self._processed_stream = processed_stream
|
||||
self._processed_stream = processed_stream
|
||||
|
||||
@property
|
||||
def llm_answer(self) -> str:
|
||||
|
||||
@@ -1012,6 +1012,7 @@ def stream_chat_message_objects(
|
||||
tools=tools,
|
||||
db_session=db_session,
|
||||
use_agentic_search=new_msg_req.use_agentic_search,
|
||||
skip_gen_ai_answer_generation=new_msg_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
|
||||
info_by_subq: dict[SubQuestionKey, AnswerPostInfo] = defaultdict(
|
||||
|
||||
@@ -187,8 +187,12 @@ class AnswerPromptBuilder:
|
||||
|
||||
final_messages_with_tokens.append(self.user_message_and_token_cnt)
|
||||
|
||||
if self.new_messages_and_token_cnts:
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
|
||||
if (
|
||||
self.new_messages_and_token_cnts
|
||||
and isinstance(self.user_message_and_token_cnt[0].content, str)
|
||||
and self.user_message_and_token_cnt[0].content.startswith("Refer")
|
||||
):
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts[-2:])
|
||||
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
|
||||
@@ -787,3 +787,7 @@ S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")
|
||||
# Forcing Vespa Language
|
||||
# English: en, German:de, etc. See: https://docs.vespa.ai/en/linguistics.html
|
||||
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")
|
||||
|
||||
HACKATHON_OUTPUT_CSV_PATH = os.environ.get(
|
||||
"HACKATHON_OUTPUT_CSV_PATH", "/tmp/hackathon_output.csv"
|
||||
)
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
import os
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
from typing import List
|
||||
|
||||
import openai
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
@@ -48,6 +51,133 @@ _FIREFLIES_API_QUERY = """
|
||||
ONE_MINUTE = 60
|
||||
|
||||
|
||||
class DocumentClassificationResult(BaseModel):
    """Structured-output schema for the LLM document-classification call."""

    # Up to 5 short category labels for the document (per the prompt).
    categories: list[str]
    # Up to 5 proper-noun entities mentioned in the document (per the prompt).
    entities: list[str]
|
||||
|
||||
|
||||
def _extract_categories_and_entities(
    sections: list[TextSection | ImageSection],
) -> dict[str, list[str]]:
    """Extract categories and entities from document sections with retry logic.

    Concatenates the text of all TextSections, sends it with a classification
    prompt to the OpenAI Responses API (model "o3"), and parses the reply into
    a DocumentClassificationResult.

    Args:
        sections: Document sections; only TextSections with non-empty text
            contribute to the classified text.

    Returns:
        {"categories": [...], "entities": [...]}; both lists are empty when
        OPENAI_API_KEY is unset, there is no text content, or all retries fail.
    """
    # Imported locally so the module does not pay for these at import time.
    import time
    import random

    prompt = """
Analyze this document, classify it with categories, and extract important entities.

CATEGORIES:
Create up to 5 simple categories that best capture what this document is about. Consider categories within:
- Document type (e.g., Manual, Report, Email, Transcript, etc.)
- Content domain (e.g., Technical, Financial, HR, Marketing, etc.)
- Purpose (e.g., Training, Reference, Announcement, Analysis, etc.)
- Industry/Topic area (e.g., Software Development, Sales, Legal, etc.)

Be creative and specific. Use clear, descriptive terms that someone searching for this document might use.
Categories should be up to 2 words each.

ENTITIES:
Extract up to 5 important proper nouns, such as:
- Company names (e.g., Microsoft, Google, Acme Corp)
- Product names (e.g., Office 365, Salesforce, iPhone)
- People's names (e.g. John, Jane, Ahmed, Wenjie, etc.)
- Department names (e.g., Engineering, Marketing, HR)
- Project names (e.g., Project Alpha, Migration 2024)
- Technology names (e.g., PostgreSQL, React, AWS)
- Location names (e.g., New York Office, Building A)
"""

    # Retry configuration: max_retries + 1 total attempts with exponential
    # backoff between failures.
    max_retries = 3
    base_delay = 1.0  # seconds
    backoff_factor = 2.0

    for attempt in range(max_retries + 1):
        try:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                logger.warning("OPENAI_API_KEY not set, skipping metadata extraction")
                return {"categories": [], "entities": []}

            client = openai.OpenAI(api_key=api_key)

            # Combine all section text (TextSections only; images are skipped).
            document_text = "\n\n".join(
                [
                    section.text
                    for section in sections
                    if isinstance(section, TextSection) and section.text.strip()
                ]
            )

            # Skip if no text content
            if not document_text.strip():
                logger.debug("No text content found, skipping metadata extraction")
                return {"categories": [], "entities": []}

            # Truncate very long documents to avoid token limits
            max_chars = 50000  # Roughly 12k tokens
            if len(document_text) > max_chars:
                document_text = document_text[:max_chars] + "..."
                logger.debug(f"Truncated document text to {max_chars} characters")

            response = client.responses.parse(
                model="o3",
                input=[
                    {
                        "role": "system",
                        "content": "Extract categories and entities from the document.",
                    },
                    {
                        "role": "user",
                        "content": prompt + "\n\nDOCUMENT: " + document_text,
                    },
                ],
                text_format=DocumentClassificationResult,
            )

            # NOTE(review): output_parsed can be None per the SDK; the
            # resulting AttributeError below is swallowed by the retry
            # loop — confirm that is the intended handling.
            classification_result = response.output_parsed

            result = {
                "categories": classification_result.categories,
                "entities": classification_result.entities,
            }

            logger.debug(f"Successfully extracted metadata: {result}")
            return result

        except Exception as e:
            attempt_num = attempt + 1
            is_last_attempt = attempt == max_retries

            # Log the error
            if is_last_attempt:
                logger.error(
                    f"Failed to extract categories and entities after {max_retries + 1} attempts: {e}"
                )
            else:
                logger.warning(
                    f"Attempt {attempt_num} failed to extract metadata: {e}. Retrying..."
                )

            # If this is the last attempt, return empty results
            if is_last_attempt:
                return {"categories": [], "entities": []}

            # Calculate delay with exponential backoff and jitter
            delay = base_delay * (backoff_factor**attempt)
            jitter = random.uniform(0.1, 0.3)  # Add 10-30% jitter
            total_delay = delay + jitter

            logger.debug(
                f"Waiting {total_delay:.2f} seconds before retry {attempt_num + 1}"
            )
            time.sleep(total_delay)

    # Should never reach here, but just in case
    return {"categories": [], "entities": []}
|
||||
|
||||
|
||||
def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
sections: List[TextSection] = []
|
||||
current_speaker_name = None
|
||||
@@ -96,12 +226,19 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
if participant != meeting_organizer_email and participant:
|
||||
meeting_participants_email_list.append(BasicExpertInfo(email=participant))
|
||||
|
||||
# Extract categories and entities from transcript and store in metadata
|
||||
categories_and_entities = _extract_categories_and_entities(sections)
|
||||
metadata = {
|
||||
"categories": categories_and_entities.get("categories", []),
|
||||
"entities": categories_and_entities.get("entities", []),
|
||||
}
|
||||
|
||||
return Document(
|
||||
id=fireflies_id,
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
source=DocumentSource.FIREFLIES,
|
||||
semantic_identifier=meeting_title,
|
||||
metadata={},
|
||||
metadata=metadata,
|
||||
doc_updated_at=meeting_date,
|
||||
primary_owners=organizer_email_user_info,
|
||||
secondary_owners=meeting_participants_email_list,
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
import io
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import openai
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
from googleapiclient.http import MediaIoBaseDownload # type: ignore
|
||||
from pydantic import BaseModel
|
||||
@@ -45,6 +47,12 @@ from onyx.utils.variable_functionality import noop_fallback
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class DocumentClassificationResult(BaseModel):
    """Structured-output schema for the LLM document-classification call."""

    # Up to 5 short category labels for the document (per the prompt).
    categories: list[str]
    # Up to 5 proper-noun entities mentioned in the document (per the prompt).
    entities: list[str]
|
||||
|
||||
|
||||
# This is not a standard valid unicode char, it is used by the docs advanced API to
|
||||
# represent smart chips (elements like dates and doc links).
|
||||
SMART_CHIP_CHAR = "\ue907"
|
||||
@@ -406,6 +414,128 @@ def convert_drive_item_to_document(
|
||||
return first_error
|
||||
|
||||
|
||||
def _extract_categories_and_entities(
    sections: list[TextSection | ImageSection],
) -> dict[str, list[str]]:
    """Extract categories and entities from document sections with retry logic.

    Concatenates the text of all TextSections, sends it with a classification
    prompt to the OpenAI Responses API (model "o3"), and parses the reply into
    a DocumentClassificationResult.

    Args:
        sections: Document sections; only TextSections with non-empty text
            contribute to the classified text.

    Returns:
        {"categories": [...], "entities": [...]}; both lists are empty when
        OPENAI_API_KEY is unset, there is no text content, or all retries fail.
    """
    # Imported locally so the module does not pay for these at import time.
    import time
    import random

    prompt = """
Analyze this document, classify it with categories, and extract important entities.

CATEGORIES:
Create up to 5 simple categories that best capture what this document is about. Consider categories within:
- Document type (e.g., Manual, Report, Email, Transcript, etc.)
- Content domain (e.g., Technical, Financial, HR, Marketing, etc.)
- Purpose (e.g., Training, Reference, Announcement, Analysis, etc.)
- Industry/Topic area (e.g., Software Development, Sales, Legal, etc.)

Be creative and specific. Use clear, descriptive terms that someone searching for this document might use.
Categories should be up to 2 words each.

ENTITIES:
Extract up to 5 important proper nouns, such as:
- Company names (e.g., Microsoft, Google, Acme Corp)
- Product names (e.g., Office 365, Salesforce, iPhone)
- People's names (e.g. John, Jane, Ahmed, Wenjie, etc.)
- Department names (e.g., Engineering, Marketing, HR)
- Project names (e.g., Project Alpha, Migration 2024)
- Technology names (e.g., PostgreSQL, React, AWS)
- Location names (e.g., New York Office, Building A)
"""

    # Retry configuration: max_retries + 1 total attempts with exponential
    # backoff between failures.
    max_retries = 3
    base_delay = 1.0  # seconds
    backoff_factor = 2.0

    for attempt in range(max_retries + 1):
        try:
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                logger.warning("OPENAI_API_KEY not set, skipping metadata extraction")
                return {"categories": [], "entities": []}

            client = openai.OpenAI(api_key=api_key)

            # Combine all section text (TextSections only; images are skipped).
            document_text = "\n\n".join(
                [
                    section.text
                    for section in sections
                    if isinstance(section, TextSection) and section.text.strip()
                ]
            )

            # Skip if no text content
            if not document_text.strip():
                logger.debug("No text content found, skipping metadata extraction")
                return {"categories": [], "entities": []}

            # Truncate very long documents to avoid token limits
            max_chars = 50000  # Roughly 12k tokens
            if len(document_text) > max_chars:
                document_text = document_text[:max_chars] + "..."
                logger.debug(f"Truncated document text to {max_chars} characters")

            response = client.responses.parse(
                model="o3",
                input=[
                    {
                        "role": "system",
                        "content": "Extract categories and entities from the document.",
                    },
                    {
                        "role": "user",
                        "content": prompt + "\n\nDOCUMENT: " + document_text,
                    },
                ],
                text_format=DocumentClassificationResult,
            )

            # NOTE(review): output_parsed can be None per the SDK; the
            # resulting AttributeError below is swallowed by the retry
            # loop — confirm that is the intended handling.
            classification_result = response.output_parsed

            result = {
                "categories": classification_result.categories,
                "entities": classification_result.entities,
            }

            logger.debug(f"Successfully extracted metadata: {result}")
            return result

        except Exception as e:
            attempt_num = attempt + 1
            is_last_attempt = attempt == max_retries

            # Log the error
            if is_last_attempt:
                logger.error(
                    f"Failed to extract categories and entities after {max_retries + 1} attempts: {e}"
                )
            else:
                logger.warning(
                    f"Attempt {attempt_num} failed to extract metadata: {e}. Retrying..."
                )

            # If this is the last attempt, return empty results
            if is_last_attempt:
                return {"categories": [], "entities": []}

            # Calculate delay with exponential backoff and jitter
            delay = base_delay * (backoff_factor**attempt)
            jitter = random.uniform(0.1, 0.3)  # Add 10-30% jitter
            total_delay = delay + jitter

            logger.debug(
                f"Waiting {total_delay:.2f} seconds before retry {attempt_num + 1}"
            )
            time.sleep(total_delay)

    # Should never reach here, but just in case
    return {"categories": [], "entities": []}
|
||||
|
||||
|
||||
def _convert_drive_item_to_document(
|
||||
creds: Any,
|
||||
allow_images: bool,
|
||||
@@ -499,17 +629,23 @@ def _convert_drive_item_to_document(
|
||||
else None
|
||||
)
|
||||
|
||||
# Extract categories and entities from drive item and store in metadata
|
||||
categories_and_entities = _extract_categories_and_entities(sections)
|
||||
metadata = {
|
||||
"owner_names": ", ".join(
|
||||
owner.get("displayName", "") for owner in file.get("owners", [])
|
||||
),
|
||||
"categories": categories_and_entities.get("categories", []),
|
||||
"entities": categories_and_entities.get("entities", []),
|
||||
}
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=doc_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file.get("name", ""),
|
||||
metadata={
|
||||
"owner_names": ", ".join(
|
||||
owner.get("displayName", "") for owner in file.get("owners", [])
|
||||
),
|
||||
},
|
||||
metadata=metadata,
|
||||
doc_updated_at=datetime.fromisoformat(
|
||||
file.get("modifiedTime", "").replace("Z", "+00:00")
|
||||
),
|
||||
|
||||
@@ -309,10 +309,30 @@ def docx_to_text_and_images(
|
||||
paragraphs = []
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
|
||||
# Debug: Check file properties before processing
|
||||
file.seek(0)
|
||||
first_bytes = file.read(100) # Read first 100 bytes
|
||||
file.seek(0) # Reset position
|
||||
|
||||
logger.debug(f"Processing file: {file_name}")
|
||||
logger.debug(f"File size: {getattr(file, 'size', 'unknown')}")
|
||||
logger.debug(f"First 100 bytes: {first_bytes}")
|
||||
logger.debug(
|
||||
f"File type check - starts with PK (ZIP): {first_bytes.startswith(b'PK')}"
|
||||
)
|
||||
|
||||
try:
|
||||
doc = docx.Document(file)
|
||||
except BadZipFile as e:
|
||||
logger.warning(f"Failed to extract text from {file_name or 'docx file'}: {e}")
|
||||
logger.error(f"BadZipFile error for {file_name}: {e}")
|
||||
logger.error(f"File first bytes: {first_bytes}")
|
||||
logger.error(f"Is this actually a ZIP file? {first_bytes.startswith(b'PK')}")
|
||||
return "", []
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Unexpected error processing DOCX file {file_name}: {type(e).__name__}: {e}"
|
||||
)
|
||||
logger.error(f"File first bytes: {first_bytes}")
|
||||
return "", []
|
||||
|
||||
# Grab text from paragraphs
|
||||
@@ -523,7 +543,7 @@ def extract_text_and_images(
|
||||
# docx example for embedded images
|
||||
if extension == ".docx":
|
||||
file.seek(0)
|
||||
text_content, images = docx_to_text_and_images(file)
|
||||
text_content, images = docx_to_text_and_images(file, file_name=file_name)
|
||||
return ExtractionResult(
|
||||
text_content=text_content, embedded_images=images, metadata={}
|
||||
)
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
import copy
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
@@ -19,6 +22,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
|
||||
from onyx.chat.prune_and_merge import prune_and_merge_sections
|
||||
from onyx.chat.prune_and_merge import prune_sections
|
||||
from onyx.configs.app_configs import HACKATHON_OUTPUT_CSV_PATH
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
@@ -62,6 +66,39 @@ SEARCH_EVALUATION_ID = "llm_doc_eval"
|
||||
QUERY_FIELD = "query"
|
||||
|
||||
|
||||
def _append_ranking_stats_to_csv(
    llm_doc_results: list[tuple[int, str]],
    query: str,
    csv_path: str = HACKATHON_OUTPUT_CSV_PATH,
) -> None:
    """
    Append ranking statistics to a CSV file.

    Args:
        llm_doc_results: (position, document_id) tuples in rank order.
        query: The query the ranking belongs to.
        csv_path: Path to the CSV file to append to (created if missing).
    """
    file_exists = os.path.isfile(csv_path)

    # Create directory if it doesn't exist
    csv_dir = os.path.dirname(csv_path)
    if csv_dir and not os.path.exists(csv_dir):
        Path(csv_dir).mkdir(parents=True, exist_ok=True)

    with open(csv_path, mode="a", newline="", encoding="utf-8") as file:
        writer = csv.writer(file)

        # Write header if file is new
        if not file_exists:
            writer.writerow(["query", "position", "document_id", "answer", "score"])

        # Write the ranking stats. Pad with empty answer AND score columns so
        # every row matches the 5-column header (the old code wrote only 4).
        for pos, doc in llm_doc_results:
            writer.writerow([query, pos, doc, "", ""])

    logger.debug(f"Appended {len(llm_doc_results)} ranking stats to {csv_path}")
|
||||
|
||||
|
||||
class SearchResponseSummary(SearchQueryInfo):
|
||||
top_sections: list[InferenceSection]
|
||||
rephrased_query: str | None = None
|
||||
@@ -303,6 +340,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
kg_sources = None
|
||||
kg_chunk_id_zero_only = False
|
||||
if override_kwargs:
|
||||
precomputed_is_keyword = override_kwargs.precomputed_is_keyword
|
||||
precomputed_keywords = override_kwargs.precomputed_keywords
|
||||
precomputed_query_embedding = override_kwargs.precomputed_query_embedding
|
||||
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
|
||||
alternate_db_session = override_kwargs.alternate_db_session
|
||||
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
|
||||
@@ -499,6 +539,12 @@ def yield_search_responses(
|
||||
)
|
||||
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
|
||||
|
||||
# Append ranking statistics to a CSV file
|
||||
llm_doc_results = []
|
||||
for pos, doc in enumerate(llm_docs):
|
||||
llm_doc_results.append((pos, doc.document_id))
|
||||
# _append_ranking_stats_to_csv(llm_doc_results, query)
|
||||
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user