mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-04-07 16:02:45 +00:00)
Compare commits: cli/v0.2.1...as_metrics (34 commits)
| SHA1 |
|---|
| 821b226d25 |
| 6dc81bbb7c |
| 6989441851 |
| 683978ddb0 |
| 568bc16536 |
| 0333ff648a |
| cc76486d21 |
| 901d8c22c4 |
| 21928133e0 |
| c4af11c19b |
| ca3f3beabe |
| fa481019e8 |
| f4c826c4e5 |
| 2a3328fc3d |
| 34aa054c5d |
| cebe237705 |
| c759fb5709 |
| ffc81f6e45 |
| 2d6f746259 |
| bca02ebec6 |
| 0c75ca0579 |
| 9d3220fcfc |
| 50a216f554 |
| 8399d2ee0a |
| fd694bea8f |
| e76cbec53c |
| d66180fe13 |
| 442c94727e |
| 2f2b9a862a |
| 1f88b60abd |
| ff03d717f3 |
| 82914ad365 |
| 11ce2a62ab |
| 6311b70cc6 |
4  .gitignore  vendored
@@ -7,4 +7,6 @@
.vscode/
*.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml
/web/test-results/
backend/onyx/agent_search/main/test_data.json
backend/tests/regression/answer_quality/test_data.json
6  .vscode/env_template.txt  vendored
@@ -49,3 +49,9 @@ BING_API_KEY=<REPLACE THIS>
# Enable the full set of Danswer Enterprise Edition features
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you are using this for local testing/development)
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False

# Agent Search configs  # TODO: remove or give these proper names
AGENT_RETRIEVAL_STATS=False  # Note: This setting will incur substantial re-ranking effort
AGENT_RERANKING_STATS=True
AGENT_MAX_QUERY_RETRIEVAL_RESULTS=20
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS=20
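These four AGENT_* variables are consumed later in this diff through `onyx.configs.dev_configs` (see the imports in doc_retrieval.py and doc_reranking.py below). That config module is not part of the diff; the sketch below is a hypothetical illustration of how such env-var-backed settings might be parsed, with only the variable names confirmed by the imports.

import os

# Hypothetical sketch of onyx/configs/dev_configs.py (not shown in this diff);
# only the names are confirmed, the parsing logic is an assumption.
AGENT_RETRIEVAL_STATS = os.environ.get("AGENT_RETRIEVAL_STATS", "").lower() == "true"
AGENT_RERANKING_STATS = os.environ.get("AGENT_RERANKING_STATS", "").lower() == "true"
AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(
    os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS", "20")
)
AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int(
    os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS", "20")
)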
20  backend/onyx/agent_search/answer_question/edges.py  Normal file
@@ -0,0 +1,20 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agent_search.answer_question.states import AnswerQuestionInput
from onyx.agent_search.core_state import in_subgraph_extract_core_fields
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput


def send_to_expanded_retrieval(state: AnswerQuestionInput) -> Send | Hashable:
    print("sending to expanded retrieval via edge")

    return Send(
        "decomped_expanded_retrieval",
        ExpandedRetrievalInput(
            **in_subgraph_extract_core_fields(state),
            question=state["question"],
            dummy="1",
        ),
    )
106  backend/onyx/agent_search/answer_question/graph_builder.py  Normal file
@@ -0,0 +1,106 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agent_search.answer_question.edges import send_to_expanded_retrieval
from onyx.agent_search.answer_question.nodes.answer_check import answer_check
from onyx.agent_search.answer_question.nodes.answer_generation import answer_generation
from onyx.agent_search.answer_question.nodes.format_answer import format_answer
from onyx.agent_search.answer_question.nodes.ingest_retrieval import ingest_retrieval
from onyx.agent_search.answer_question.states import AnswerQuestionInput
from onyx.agent_search.answer_question.states import AnswerQuestionOutput
from onyx.agent_search.answer_question.states import AnswerQuestionState
from onyx.agent_search.expanded_retrieval.graph_builder import (
    expanded_retrieval_graph_builder,
)


def answer_query_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=AnswerQuestionState,
        input=AnswerQuestionInput,
        output=AnswerQuestionOutput,
    )

    ### Add nodes ###

    expanded_retrieval = expanded_retrieval_graph_builder().compile()
    graph.add_node(
        node="decomped_expanded_retrieval",
        action=expanded_retrieval,
    )
    graph.add_node(
        node="answer_check",
        action=answer_check,
    )
    graph.add_node(
        node="answer_generation",
        action=answer_generation,
    )
    graph.add_node(
        node="format_answer",
        action=format_answer,
    )
    graph.add_node(
        node="ingest_retrieval",
        action=ingest_retrieval,
    )

    ### Add edges ###

    graph.add_conditional_edges(
        source=START,
        path=send_to_expanded_retrieval,
        path_map=["decomped_expanded_retrieval"],
    )
    graph.add_edge(
        start_key="decomped_expanded_retrieval",
        end_key="ingest_retrieval",
    )
    graph.add_edge(
        start_key="ingest_retrieval",
        end_key="answer_generation",
    )
    graph.add_edge(
        start_key="answer_generation",
        end_key="answer_check",
    )
    graph.add_edge(
        start_key="answer_check",
        end_key="format_answer",
    )
    graph.add_edge(
        start_key="format_answer",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = answer_query_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    with get_session_context_manager() as db_session:
        inputs = AnswerQuestionInput(
            search_request=search_request,
            primary_llm=primary_llm,
            fast_llm=fast_llm,
            db_session=db_session,
            question="what can you do with onyx?",
        )
        for thing in compiled_graph.stream(
            input=inputs,
            # debug=True,
            # subgraphs=True,
        ):
            print(thing)
        # output = compiled_graph.invoke(inputs)
        # print(output)
19  backend/onyx/agent_search/answer_question/models.py  Normal file
@@ -0,0 +1,19 @@
from pydantic import BaseModel

from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection

### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]


class QuestionAnswerResults(BaseModel):
    question: str
    answer: str
    quality: str
    # expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats
30  backend/onyx/agent_search/answer_question/nodes/answer_check.py  Normal file
@@ -0,0 +1,30 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from onyx.agent_search.answer_question.states import AnswerQuestionState
from onyx.agent_search.answer_question.states import QACheckUpdate
from onyx.agent_search.shared_graph_utils.prompts import SUB_CHECK_PROMPT


def answer_check(state: AnswerQuestionState) -> QACheckUpdate:
    msg = [
        HumanMessage(
            content=SUB_CHECK_PROMPT.format(
                question=state["question"],
                base_answer=state["answer"],
            )
        )
    ]

    fast_llm = state["subgraph_fast_llm"]
    response = list(
        fast_llm.stream(
            prompt=msg,
        )
    )

    quality_str = merge_message_runs(response, chunk_separator="")[0].content

    return QACheckUpdate(
        answer_quality=quality_str,
    )
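A side note on the streaming pattern this node uses (and which recurs in several nodes below): the LLM response is collected as a list of streamed message chunks and stitched back into a single message with merge_message_runs. A self-contained sketch of just that mechanism, using toy messages rather than the onyx LLM interface:

from langchain_core.messages import AIMessage, merge_message_runs

# Consecutive messages of the same type are merged into one; chunk_separator=""
# concatenates the streamed pieces without inserting whitespace between them.
pieces = [AIMessage(content="Hel"), AIMessage(content="lo"), AIMessage(content="!")]
print(merge_message_runs(pieces, chunk_separator="")[0].content)  # "Hello!"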
36  backend/onyx/agent_search/answer_question/nodes/answer_generation.py  Normal file
@@ -0,0 +1,36 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from onyx.agent_search.answer_question.states import AnswerQuestionState
from onyx.agent_search.answer_question.states import QAGenerationUpdate
from onyx.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from onyx.agent_search.shared_graph_utils.utils import format_docs


def answer_generation(state: AnswerQuestionState) -> QAGenerationUpdate:
    question = state["question"]
    docs = state["documents"]

    print(f"Number of verified retrieval docs: {len(docs)}")

    msg = [
        HumanMessage(
            content=BASE_RAG_PROMPT.format(
                question=question,
                context=format_docs(docs),
                original_question=state["subgraph_search_request"].query,
            )
        )
    ]

    fast_llm = state["subgraph_fast_llm"]
    response = list(
        fast_llm.stream(
            prompt=msg,
        )
    )

    answer_str = merge_message_runs(response, chunk_separator="")[0].content
    return QAGenerationUpdate(
        answer=answer_str,
    )
28  backend/onyx/agent_search/answer_question/nodes/format_answer.py  Normal file
@@ -0,0 +1,28 @@
from onyx.agent_search.answer_question.states import AnswerQuestionOutput
from onyx.agent_search.answer_question.states import AnswerQuestionState
from onyx.agent_search.answer_question.states import QuestionAnswerResults


def format_answer(state: AnswerQuestionState) -> AnswerQuestionOutput:
    # sub_question_retrieval_stats = state["sub_question_retrieval_stats"]
    # if sub_question_retrieval_stats is None:
    #     sub_question_retrieval_stats = []
    # elif isinstance(sub_question_retrieval_stats, list):
    #     sub_question_retrieval_stats = sub_question_retrieval_stats
    #     if isinstance(sub_question_retrieval_stats[0], list):
    #         sub_question_retrieval_stats = sub_question_retrieval_stats[0]
    # else:
    #     sub_question_retrieval_stats = [sub_question_retrieval_stats]

    return AnswerQuestionOutput(
        answer_results=[
            QuestionAnswerResults(
                question=state["question"],
                quality=state["answer_quality"],
                answer=state["answer"],
                # expanded_retrieval_results=state["expanded_retrieval_results"],
                documents=state["documents"],
                sub_question_retrieval_stats=state["sub_question_retrieval_stats"],
            )
        ],
    )
19  backend/onyx/agent_search/answer_question/nodes/ingest_retrieval.py  Normal file
@@ -0,0 +1,19 @@
from onyx.agent_search.answer_question.states import RetrievalIngestionUpdate
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats


def ingest_retrieval(state: ExpandedRetrievalOutput) -> RetrievalIngestionUpdate:
    sub_question_retrieval_stats = state[
        "expanded_retrieval_result"
    ].sub_question_retrieval_stats
    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = AgentChunkStats()

    return RetrievalIngestionUpdate(
        expanded_retrieval_results=state[
            "expanded_retrieval_result"
        ].expanded_queries_results,
        documents=state["expanded_retrieval_result"].all_documents,
        sub_question_retrieval_stats=sub_question_retrieval_stats,
    )
58  backend/onyx/agent_search/answer_question/states.py  Normal file
@@ -0,0 +1,58 @@
from operator import add
from typing import Annotated
from typing import TypedDict

from onyx.agent_search.answer_question.models import QuestionAnswerResults
from onyx.agent_search.core_state import SubgraphCoreState
from onyx.agent_search.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
from onyx.context.search.models import InferenceSection


## Update States
class QACheckUpdate(TypedDict):
    answer_quality: str


class QAGenerationUpdate(TypedDict):
    answer: str
    # answer_stat: AnswerStats


class RetrievalIngestionUpdate(TypedDict):
    expanded_retrieval_results: list[QueryResult]
    documents: Annotated[list[InferenceSection], dedup_inference_sections]
    sub_question_retrieval_stats: AgentChunkStats


## Graph Input State


class AnswerQuestionInput(SubgraphCoreState):
    question: str


## Graph State


class AnswerQuestionState(
    AnswerQuestionInput,
    QAGenerationUpdate,
    QACheckUpdate,
    RetrievalIngestionUpdate,
):
    pass


## Graph Output State


class AnswerQuestionOutput(TypedDict):
    """
    This is a list of results even though each call of this subgraph only returns one result.
    This is because if we parallelize the answer query subgraph, there will be multiple
    results in a list so the add operator is used to add them together.
    """

    answer_results: Annotated[list[QuestionAnswerResults], add]
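The `Annotated[list, add]` pattern in AnswerQuestionOutput is what makes the parallelization described in its docstring work: LangGraph applies the annotated reducer (here `operator.add`) to merge each parallel invocation's partial update into the shared state. A minimal, self-contained sketch of the mechanism with toy state and nodes (not the onyx types):

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    results: Annotated[list[str], add]  # partial updates are concatenated


def node_a(state: State) -> dict:
    return {"results": ["a"]}


def node_b(state: State) -> dict:
    return {"results": ["b"]}


builder = StateGraph(State)
builder.add_node("node_a", node_a)
builder.add_node("node_b", node_b)
builder.add_edge(START, "node_a")  # both nodes run in the same superstep
builder.add_edge(START, "node_b")
builder.add_edge("node_a", END)
builder.add_edge("node_b", END)

print(builder.compile().invoke({"results": []}))  # -> {'results': ['a', 'b']}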
70  backend/onyx/agent_search/base_raw_search/graph_builder.py  Normal file
@@ -0,0 +1,70 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agent_search.base_raw_search.nodes.format_raw_search_results import (
    format_raw_search_results,
)
from onyx.agent_search.base_raw_search.nodes.generate_raw_search_data import (
    generate_raw_search_data,
)
from onyx.agent_search.base_raw_search.states import BaseRawSearchInput
from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput
from onyx.agent_search.base_raw_search.states import BaseRawSearchState
from onyx.agent_search.expanded_retrieval.graph_builder import (
    expanded_retrieval_graph_builder,
)


def base_raw_search_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=BaseRawSearchState,
        input=BaseRawSearchInput,
        output=BaseRawSearchOutput,
    )

    ### Add nodes ###

    expanded_retrieval = expanded_retrieval_graph_builder().compile()
    graph.add_node(
        node="generate_raw_search_data",
        action=generate_raw_search_data,
    )

    graph.add_node(
        node="expanded_retrieval_base_search",
        action=expanded_retrieval,
    )
    graph.add_node(
        node="format_raw_search_results",
        action=format_raw_search_results,
    )

    ### Add edges ###

    graph.add_edge(start_key=START, end_key="generate_raw_search_data")

    graph.add_edge(
        start_key="generate_raw_search_data",
        end_key="expanded_retrieval_base_search",
    )
    graph.add_edge(
        start_key="expanded_retrieval_base_search",
        end_key="format_raw_search_results",
    )

    # graph.add_edge(
    #     start_key="expanded_retrieval_base_search",
    #     end_key=END,
    # )

    graph.add_edge(
        start_key="format_raw_search_results",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    pass
20  backend/onyx/agent_search/base_raw_search/models.py  Normal file
@@ -0,0 +1,20 @@
from pydantic import BaseModel

from onyx.agent_search.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.context.search.models import InferenceSection

### Models ###


class AnswerRetrievalStats(BaseModel):
    answer_retrieval_stats: dict[str, float | int]


class QuestionAnswerResults(BaseModel):
    question: str
    answer: str
    quality: str
    expanded_retrieval_results: list[QueryResult]
    documents: list[InferenceSection]
    sub_question_retrieval_stats: list[AgentChunkStats]
11  backend/onyx/agent_search/base_raw_search/nodes/format_raw_search_results.py  Normal file
@@ -0,0 +1,11 @@
from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput


def format_raw_search_results(state: ExpandedRetrievalOutput) -> BaseRawSearchOutput:
    print("format_raw_search_results")
    return BaseRawSearchOutput(
        base_expanded_retrieval_result=state["expanded_retrieval_result"],
        # base_retrieval_results=[state["expanded_retrieval_result"]],
        # base_search_documents=[],
    )
15  backend/onyx/agent_search/base_raw_search/nodes/generate_raw_search_data.py  Normal file
@@ -0,0 +1,15 @@
from onyx.agent_search.core_state import CoreState
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput


def generate_raw_search_data(state: CoreState) -> ExpandedRetrievalInput:
    print("generate_raw_search_data")
    return ExpandedRetrievalInput(
        subgraph_search_request=state["search_request"],
        subgraph_primary_llm=state["primary_llm"],
        subgraph_fast_llm=state["fast_llm"],
        subgraph_db_session=state["db_session"],
        question=state["search_request"].query,
        dummy="7",
        base_search=True,
    )
40  backend/onyx/agent_search/base_raw_search/states.py  Normal file
@@ -0,0 +1,40 @@
from typing import TypedDict

from onyx.agent_search.core_state import CoreState
from onyx.agent_search.core_state import SubgraphCoreState
from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult


## Update States


## Graph Input State


class BaseRawSearchInput(CoreState, SubgraphCoreState):
    pass


## Graph Output State


class BaseRawSearchOutput(TypedDict):
    """
    This is a list of results even though each call of this subgraph only returns one result.
    This is because if we parallelize the answer query subgraph, there will be multiple
    results in a list so the add operator is used to add them together.
    """

    # base_search_documents: Annotated[list[InferenceSection], dedup_inference_sections]
    # base_retrieval_results: Annotated[list[ExpandedRetrievalResult], add]
    base_expanded_retrieval_result: ExpandedRetrievalResult


## Graph State


class BaseRawSearchState(
    BaseRawSearchInput,
    BaseRawSearchOutput,
):
    pass
61  backend/onyx/agent_search/core_state.py  Normal file
@@ -0,0 +1,61 @@
from operator import add
from typing import Annotated
from typing import TypedDict
from typing import TypeVar

from sqlalchemy.orm import Session

from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM


class CoreState(TypedDict, total=False):
    """
    This is the core state that is shared across all subgraphs.
    """

    search_request: SearchRequest
    primary_llm: LLM
    fast_llm: LLM
    # a single session for the entire agent search
    # is fine if we are only reading
    db_session: Session
    log_messages: Annotated[list[str], add]
    dummy: str


class SubgraphCoreState(TypedDict, total=False):
    """
    This is the subgraph-prefixed version of the core state that is shared across all subgraphs.
    """

    subgraph_search_request: SearchRequest
    subgraph_primary_llm: LLM
    subgraph_fast_llm: LLM
    # a single session for the entire agent search
    # is fine if we are only reading
    subgraph_db_session: Session


# This ensures that the state passed in extends the CoreState
T = TypeVar("T", bound=CoreState)
T_SUBGRAPH = TypeVar("T_SUBGRAPH", bound=SubgraphCoreState)


def extract_core_fields(state: T) -> CoreState:
    filtered_dict = {k: v for k, v in state.items() if k in CoreState.__annotations__}
    return CoreState(**dict(filtered_dict))  # type: ignore


def extract_core_fields_for_subgraph(state: T) -> SubgraphCoreState:
    filtered_dict = {
        "subgraph_" + k: v for k, v in state.items() if k in CoreState.__annotations__
    }
    return SubgraphCoreState(**dict(filtered_dict))  # type: ignore


def in_subgraph_extract_core_fields(state: T_SUBGRAPH) -> SubgraphCoreState:
    filtered_dict = {
        k: v for k, v in state.items() if k in SubgraphCoreState.__annotations__
    }
    return SubgraphCoreState(**dict(filtered_dict))
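These helpers are plain key filtering and prefixing over the state dicts: when a subgraph is invoked, the parent graph's core fields are copied under a "subgraph_" prefix so the two namespaces cannot collide. A toy illustration of the transformation, using plain dicts rather than the real onyx types:

# Toy sketch of the prefixing behavior in extract_core_fields_for_subgraph.
CORE_KEYS = {"search_request", "primary_llm", "fast_llm", "db_session"}


def to_subgraph_state(state: dict) -> dict:
    # Keep only core keys and rename them with the subgraph_ prefix.
    return {f"subgraph_{k}": v for k, v in state.items() if k in CORE_KEYS}


print(to_subgraph_state({"search_request": "req", "fast_llm": "llm", "extra": 1}))
# -> {'subgraph_search_request': 'req', 'subgraph_fast_llm': 'llm'}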
0  backend/onyx/agent_search/deep_answer/edges.py  Normal file
114  backend/onyx/agent_search/deep_answer/nodes/answer_generation.py  Normal file
@@ -0,0 +1,114 @@
from typing import Any

from langchain_core.messages import HumanMessage

from onyx.agent_search.main.states import MainState
from onyx.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
from onyx.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
from onyx.agent_search.shared_graph_utils.utils import format_docs
from onyx.agent_search.shared_graph_utils.utils import normalize_whitespace


# aggregate sub questions and answers
def deep_answer_generation(state: MainState) -> dict[str, Any]:
    """
    Generate answer

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with re-phrased question
    """
    print("---DEEP GENERATE---")

    question = state["original_question"]
    docs = state["deduped_retrieval_docs"]

    deep_answer_context = state["core_answer_dynamic_context"]

    print(f"Number of verified retrieval docs - deep: {len(docs)}")

    combined_context = normalize_whitespace(
        COMBINED_CONTEXT.format(
            deep_answer_context=deep_answer_context, formated_docs=format_docs(docs)
        )
    )

    msg = [
        HumanMessage(
            content=MODIFIED_RAG_PROMPT.format(
                question=question, combined_context=combined_context
            )
        )
    ]

    # Grader
    model = state["fast_llm"]
    response = model.invoke(msg)

    return {
        "deep_answer": response.content,
    }


def final_stuff(state: MainState) -> dict[str, Any]:
    """
    Invokes the agent model to generate a response based on the current state. Given
    the question, it will decide to retrieve using the retriever tool, or simply end.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the agent response appended to messages
    """
    print("---FINAL---")

    messages = state["log_messages"]
    time_ordered_messages = [x.pretty_repr() for x in messages]
    time_ordered_messages.sort()

    print("Message Log:")
    # print("\n".join(time_ordered_messages))

    initial_sub_qas = state["initial_sub_qas"]
    initial_sub_qa_list = []
    for initial_sub_qa in initial_sub_qas:
        if initial_sub_qa["sub_answer_check"] == "yes":
            initial_sub_qa_list.append(
                f'  Question:\n  {initial_sub_qa["sub_question"]}\n  --\n  Answer:\n  {initial_sub_qa["sub_answer"]}\n  -----'
            )

    initial_sub_qa_context = "\n".join(initial_sub_qa_list)

    base_answer = state["base_answer"]

    print(f"Final Base Answer:\n{base_answer}")
    print("--------------------------------")
    print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
    print("--------------------------------")

    if not state.get("deep_answer"):
        print("No Deep Answer was required")
        return {}

    deep_answer = state["deep_answer"]
    sub_qas = state["sub_qas"]
    sub_qa_list = []
    for sub_qa in sub_qas:
        if sub_qa["sub_answer_check"] == "yes":
            sub_qa_list.append(
                f'  Question:\n  {sub_qa["sub_question"]}\n  --\n  Answer:\n  {sub_qa["sub_answer"]}\n  -----'
            )

    sub_qa_context = "\n".join(sub_qa_list)

    print(f"Final Base Answer:\n{base_answer}")
    print("--------------------------------")
    print(f"Final Deep Answer:\n{deep_answer}")
    print("--------------------------------")
    print("Sub Questions and Answers:")
    print(sub_qa_context)

    return {}
78  backend/onyx/agent_search/deep_answer/nodes/deep_decomp.py  Normal file
@@ -0,0 +1,78 @@
import json
import re
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from onyx.agent_search.main.states import MainState
from onyx.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
from onyx.agent_search.shared_graph_utils.utils import format_entity_term_extraction
from onyx.agent_search.shared_graph_utils.utils import generate_log_message


def decompose(state: MainState) -> dict[str, Any]:
    """Decompose the question into new sub-questions, informed by the base answer,
    the extracted entities/terms, and which earlier sub-questions succeeded or failed."""

    node_start_time = datetime.now()

    question = state["original_question"]
    base_answer = state["base_answer"]

    # get the entity term extraction dict and properly format it
    entity_term_extraction_dict = state["retrieved_entities_relationships"][
        "retrieved_entities_relationships"
    ]

    entity_term_extraction_str = format_entity_term_extraction(
        entity_term_extraction_dict
    )

    initial_question_answers = state["initial_sub_qas"]

    addressed_question_list = [
        x["sub_question"]
        for x in initial_question_answers
        if x["sub_answer_check"] == "yes"
    ]
    failed_question_list = [
        x["sub_question"]
        for x in initial_question_answers
        if x["sub_answer_check"] == "no"
    ]

    msg = [
        HumanMessage(
            content=DEEP_DECOMPOSE_PROMPT.format(
                question=question,
                entity_term_extraction_str=entity_term_extraction_str,
                base_answer=base_answer,
                answered_sub_questions="\n - ".join(addressed_question_list),
                failed_sub_questions="\n - ".join(failed_question_list),
            ),
        )
    ]

    # Grader
    model = state["fast_llm"]
    response = model.invoke(msg)

    cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr())
    parsed_response = json.loads(cleaned_response)

    sub_questions_dict = {}
    for sub_question_nr, sub_question_dict in enumerate(
        parsed_response["sub_questions"]
    ):
        sub_question_dict["answered"] = False
        sub_question_dict["verified"] = False
        sub_questions_dict[sub_question_nr] = sub_question_dict

    return {
        "decomposed_sub_questions_dict": sub_questions_dict,
        "log_messages": generate_log_message(
            message="deep - decompose",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -0,0 +1,40 @@
import json
import re
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from onyx.agent_search.main.states import MainState
from onyx.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
from onyx.agent_search.shared_graph_utils.utils import format_docs


def entity_term_extraction(state: MainState) -> dict[str, Any]:
    """Extract entities and terms from the question and context"""

    question = state["original_question"]
    docs = state["deduped_retrieval_docs"]

    doc_context = format_docs(docs)

    msg = [
        HumanMessage(
            content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
        )
    ]
    fast_llm = state["fast_llm"]
    # Grader
    llm_response_list = list(
        fast_llm.stream(
            prompt=msg,
        )
    )
    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

    cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
    parsed_response = json.loads(cleaned_response)

    return {
        "retrieved_entities_relationships": parsed_response,
    }
@@ -0,0 +1,30 @@
from typing import Any

from onyx.agent_search.main.states import MainState


# aggregate sub questions and answers
def sub_qa_level_aggregator(state: MainState) -> dict[str, Any]:
    sub_qas = state["sub_qas"]

    dynamic_context_list = [
        "Below you will find useful information to answer the original question:"
    ]
    checked_sub_qas = []

    for core_answer_sub_qa in sub_qas:
        question = core_answer_sub_qa["sub_question"]
        answer = core_answer_sub_qa["sub_answer"]
        verified = core_answer_sub_qa["sub_answer_check"]

        if verified == "yes":
            dynamic_context_list.append(
                f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n"
            )
            checked_sub_qas.append({"sub_question": question, "sub_answer": answer})
    dynamic_context = "\n".join(dynamic_context_list)

    return {
        "core_answer_dynamic_context": dynamic_context,
        "checked_sub_qas": checked_sub_qas,
    }
@@ -0,0 +1,19 @@
from typing import Any

from onyx.agent_search.main.states import MainState


def sub_qa_manager(state: MainState) -> dict[str, Any]:
    """Collect the decomposed sub-questions into a numbered dict and reset the iteration counter."""

    sub_questions_dict = state["decomposed_sub_questions_dict"]

    sub_questions = {}

    for sub_question_nr, sub_question_dict in sub_questions_dict.items():
        sub_questions[sub_question_nr] = sub_question_dict["sub_question"]

    return {
        "sub_questions": sub_questions,
        "num_new_question_iterations": 0,
    }
0  backend/onyx/agent_search/deep_answer/states.py  Normal file
24  backend/onyx/agent_search/expanded_retrieval/edges.py  Normal file
@@ -0,0 +1,24 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agent_search.core_state import in_subgraph_extract_core_fields
from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrievalInput
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState


def parallel_retrieval_edge(state: ExpandedRetrievalState) -> list[Send | Hashable]:
    question = state.get("question", state["subgraph_search_request"].query)

    query_expansions = state.get("expanded_queries", []) + [question]
    return [
        Send(
            "doc_retrieval",
            RetrievalInput(
                query_to_retrieve=query,
                question=question,
                **in_subgraph_extract_core_fields(state),
            ),
        )
        for query in query_expansions
    ]
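parallel_retrieval_edge is the LangGraph fan-out idiom: a conditional edge returns one Send per expanded query, and each Send invokes doc_retrieval with its own input state, all in parallel, with reducers merging the results back together. A minimal, self-contained sketch of the same mechanism with toy state (not the onyx types):

from operator import add
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class State(TypedDict):
    queries: list[str]
    docs: Annotated[list[str], add]  # reducer merges the parallel updates


def fan_out(state: State) -> list[Send]:
    # One "retrieve" invocation per query; each Send carries its own input state.
    return [Send("retrieve", {"queries": [q], "docs": []}) for q in state["queries"]]


def retrieve(state: State) -> dict:
    return {"docs": [f"doc-for-{state['queries'][0]}"]}


builder = StateGraph(State)
builder.add_node("retrieve", retrieve)
builder.add_conditional_edges(START, fan_out, ["retrieve"])
builder.add_edge("retrieve", END)

print(builder.compile().invoke({"queries": ["q1", "q2"], "docs": []}))
# -> {'queries': ['q1', 'q2'], 'docs': ['doc-for-q1', 'doc-for-q2']}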
111  backend/onyx/agent_search/expanded_retrieval/graph_builder.py  Normal file
@@ -0,0 +1,111 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agent_search.expanded_retrieval.edges import parallel_retrieval_edge
from onyx.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking
from onyx.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval
from onyx.agent_search.expanded_retrieval.nodes.doc_verification import (
    doc_verification,
)
from onyx.agent_search.expanded_retrieval.nodes.expand_queries import expand_queries
from onyx.agent_search.expanded_retrieval.nodes.format_results import format_results
from onyx.agent_search.expanded_retrieval.nodes.verification_kickoff import (
    verification_kickoff,
)
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState


def expanded_retrieval_graph_builder() -> StateGraph:
    graph = StateGraph(
        state_schema=ExpandedRetrievalState,
        input=ExpandedRetrievalInput,
        output=ExpandedRetrievalOutput,
    )

    ### Add nodes ###

    graph.add_node(
        node="expand_queries",
        action=expand_queries,
    )

    graph.add_node(
        node="doc_retrieval",
        action=doc_retrieval,
    )
    graph.add_node(
        node="verification_kickoff",
        action=verification_kickoff,
    )
    graph.add_node(
        node="doc_verification",
        action=doc_verification,
    )
    graph.add_node(
        node="doc_reranking",
        action=doc_reranking,
    )
    graph.add_node(
        node="format_results",
        action=format_results,
    )

    ### Add edges ###
    graph.add_edge(
        start_key=START,
        end_key="expand_queries",
    )

    graph.add_conditional_edges(
        source="expand_queries",
        path=parallel_retrieval_edge,
        path_map=["doc_retrieval"],
    )
    graph.add_edge(
        start_key="doc_retrieval",
        end_key="verification_kickoff",
    )
    graph.add_edge(
        start_key="doc_verification",
        end_key="doc_reranking",
    )
    graph.add_edge(
        start_key="doc_reranking",
        end_key="format_results",
    )
    graph.add_edge(
        start_key="format_results",
        end_key=END,
    )

    return graph


if __name__ == "__main__":
    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = expanded_retrieval_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    with get_session_context_manager() as db_session:
        inputs = ExpandedRetrievalInput(
            search_request=search_request,
            primary_llm=primary_llm,
            fast_llm=fast_llm,
            db_session=db_session,
            question="what can you do with onyx?",
        )
        for thing in compiled_graph.stream(
            input=inputs,
            # debug=True,
            subgraphs=True,
        ):
            print(thing)
19  backend/onyx/agent_search/expanded_retrieval/models.py  Normal file
@@ -0,0 +1,19 @@
from pydantic import BaseModel

from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.context.search.models import InferenceSection

### Models ###


class QueryResult(BaseModel):
    query: str
    search_results: list[InferenceSection]
    stats: RetrievalFitStats | None


class ExpandedRetrievalResult(BaseModel):
    expanded_queries_results: list[QueryResult]
    all_documents: list[InferenceSection]
    sub_question_retrieval_stats: AgentChunkStats
47  backend/onyx/agent_search/expanded_retrieval/nodes/doc_reranking.py  Normal file
@@ -0,0 +1,47 @@
from onyx.agent_search.expanded_retrieval.states import DocRerankingUpdate
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState
from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.configs.dev_configs import AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RERANKING_STATS
from onyx.context.search.pipeline import InferenceSection
from onyx.context.search.pipeline import retrieval_preprocessing
from onyx.context.search.pipeline import search_postprocessing
from onyx.context.search.pipeline import SearchRequest


def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingUpdate:
    verified_documents = state["verified_documents"]

    # Rerank post retrieval and verification. First, create a search query
    # then create the list of reranked sections

    question = state.get("question", state["subgraph_search_request"].query)
    _search_query = retrieval_preprocessing(
        search_request=SearchRequest(query=question),
        user=None,
        llm=state["subgraph_fast_llm"],
        db_session=state["subgraph_db_session"],
    )

    reranked_documents = list(
        search_postprocessing(
            search_query=_search_query,
            retrieved_sections=verified_documents,
            llm=state["subgraph_fast_llm"],
        )
    )[
        0
    ]  # only get the reranked sections, not the SectionRelevancePiece

    if AGENT_RERANKING_STATS:
        fit_scores = get_fit_scores(verified_documents, reranked_documents)
    else:
        fit_scores = RetrievalFitStats(fit_score_lift=0, rerank_effect=0, fit_scores={})

    return DocRerankingUpdate(
        reranked_documents=[
            doc for doc in reranked_documents if type(doc) == InferenceSection
        ][:AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS],
        sub_question_retrieval_stats=fit_scores,
    )
55  backend/onyx/agent_search/expanded_retrieval/nodes/doc_retrieval.py  Normal file
@@ -0,0 +1,55 @@
from onyx.agent_search.expanded_retrieval.models import QueryResult
from onyx.agent_search.expanded_retrieval.states import DocRetrievalUpdate
from onyx.agent_search.expanded_retrieval.states import RetrievalInput
from onyx.agent_search.shared_graph_utils.calculations import get_fit_scores
from onyx.configs.dev_configs import AGENT_MAX_QUERY_RETRIEVAL_RESULTS
from onyx.configs.dev_configs import AGENT_RETRIEVAL_STATS
from onyx.context.search.models import SearchRequest
from onyx.context.search.pipeline import SearchPipeline


def doc_retrieval(state: RetrievalInput) -> DocRetrievalUpdate:
    """
    Retrieve documents

    Args:
        state (RetrievalInput): Primary state + the query to retrieve

    Updates:
        expanded_retrieval_results: list[ExpandedRetrievalResult]
        retrieved_documents: list[InferenceSection]
    """

    llm = state["subgraph_primary_llm"]
    fast_llm = state["subgraph_fast_llm"]
    query_to_retrieve = state["query_to_retrieve"]

    search_results = SearchPipeline(
        search_request=SearchRequest(
            query=query_to_retrieve,
        ),
        user=None,
        llm=llm,
        fast_llm=fast_llm,
        db_session=state["subgraph_db_session"],
    )

    retrieved_docs = search_results._get_sections()[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]

    if AGENT_RETRIEVAL_STATS:
        fit_scores = get_fit_scores(
            retrieved_docs,
            search_results.reranked_sections[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS],
        )
    else:
        fit_scores = None

    expanded_retrieval_result = QueryResult(
        query=query_to_retrieve,
        search_results=retrieved_docs,
        stats=fit_scores,
    )
    return DocRetrievalUpdate(
        expanded_retrieval_results=[expanded_retrieval_result],
        retrieved_documents=retrieved_docs,
    )
41  backend/onyx/agent_search/expanded_retrieval/nodes/doc_verification.py  Normal file
@@ -0,0 +1,41 @@
from langchain_core.messages import HumanMessage

from onyx.agent_search.expanded_retrieval.states import DocVerificationInput
from onyx.agent_search.expanded_retrieval.states import DocVerificationUpdate
from onyx.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT


def doc_verification(state: DocVerificationInput) -> DocVerificationUpdate:
    """
    Check whether the document is relevant for the original user question

    Args:
        state (DocVerificationInput): The current state

    Updates:
        verified_documents: list[InferenceSection]
    """

    question = state["question"]
    doc_to_verify = state["doc_to_verify"]
    document_content = doc_to_verify.combined_content

    msg = [
        HumanMessage(
            content=VERIFIER_PROMPT.format(
                question=question, document_content=document_content
            )
        )
    ]

    fast_llm = state["subgraph_fast_llm"]

    response = fast_llm.invoke(msg)

    verified_documents = []
    if "yes" in response.content.lower():
        verified_documents.append(doc_to_verify)

    return DocVerificationUpdate(
        verified_documents=verified_documents,
    )
30  backend/onyx/agent_search/expanded_retrieval/nodes/expand_queries.py  Normal file
@@ -0,0 +1,30 @@
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
from onyx.agent_search.expanded_retrieval.states import QueryExpansionUpdate
from onyx.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI_ORIGINAL
from onyx.llm.interfaces import LLM


def expand_queries(state: ExpandedRetrievalInput) -> QueryExpansionUpdate:
    question = state.get("question")
    llm: LLM = state["subgraph_fast_llm"]

    msg = [
        HumanMessage(
            content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
        )
    ]
    llm_response_list = list(
        llm.stream(
            prompt=msg,
        )
    )
    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

    rewritten_queries = llm_response.split("--")

    return QueryExpansionUpdate(
        expanded_queries=rewritten_queries,
    )
99  backend/onyx/agent_search/expanded_retrieval/nodes/format_results.py  Normal file
@@ -0,0 +1,99 @@
from collections import defaultdict

import numpy as np

from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult
from onyx.agent_search.expanded_retrieval.models import QueryResult
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState
from onyx.agent_search.expanded_retrieval.states import InferenceSection
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats


def _calculate_sub_question_retrieval_stats(
    verified_documents: list[InferenceSection],
    expanded_retrieval_results: list[QueryResult],
) -> AgentChunkStats:
    chunk_scores: dict[str, dict[str, list[int | float]]] = defaultdict(
        lambda: defaultdict(list)
    )

    for expanded_retrieval_result in expanded_retrieval_results:
        for doc in expanded_retrieval_result.search_results:
            doc_chunk_id = f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"
            if doc.center_chunk.score is not None:
                chunk_scores[doc_chunk_id]["score"].append(doc.center_chunk.score)

    verified_doc_chunk_ids = [
        f"{verified_document.center_chunk.document_id}_{verified_document.center_chunk.chunk_id}"
        for verified_document in verified_documents
    ]
    dismissed_doc_chunk_ids = []

    raw_chunk_stats_counts: dict[str, int] = defaultdict(int)
    raw_chunk_stats_scores: dict[str, float] = defaultdict(float)
    for doc_chunk_id, chunk_data in chunk_scores.items():
        if doc_chunk_id in verified_doc_chunk_ids:
            raw_chunk_stats_counts["verified_count"] += 1

            valid_chunk_scores = [
                score for score in chunk_data["score"] if score is not None
            ]
            raw_chunk_stats_scores["verified_scores"] += float(
                np.mean(valid_chunk_scores)
            )
        else:
            raw_chunk_stats_counts["rejected_count"] += 1
            valid_chunk_scores = [
                score for score in chunk_data["score"] if score is not None
            ]
            raw_chunk_stats_scores["rejected_scores"] += float(
                np.mean(valid_chunk_scores)
            )
            dismissed_doc_chunk_ids.append(doc_chunk_id)

    if raw_chunk_stats_counts["verified_count"] == 0:
        verified_avg_scores = 0.0
    else:
        verified_avg_scores = raw_chunk_stats_scores["verified_scores"] / float(
            raw_chunk_stats_counts["verified_count"]
        )

    rejected_scores = raw_chunk_stats_scores.get("rejected_scores", None)
    if rejected_scores is not None:
        rejected_avg_scores = rejected_scores / float(
            raw_chunk_stats_counts["rejected_count"]
        )
    else:
        rejected_avg_scores = None

    chunk_stats = AgentChunkStats(
        verified_count=raw_chunk_stats_counts["verified_count"],
        verified_avg_scores=verified_avg_scores,
        rejected_count=raw_chunk_stats_counts["rejected_count"],
        rejected_avg_scores=rejected_avg_scores,
        verified_doc_chunk_ids=verified_doc_chunk_ids,
        dismissed_doc_chunk_ids=dismissed_doc_chunk_ids,
    )

    return chunk_stats


def format_results(state: ExpandedRetrievalState) -> ExpandedRetrievalOutput:
    sub_question_retrieval_stats = _calculate_sub_question_retrieval_stats(
        verified_documents=state["verified_documents"],
        expanded_retrieval_results=state["expanded_retrieval_results"],
    )

    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = AgentChunkStats()
    # else:
    #     sub_question_retrieval_stats = [sub_question_retrieval_stats]

    return ExpandedRetrievalOutput(
        expanded_retrieval_result=ExpandedRetrievalResult(
            expanded_queries_results=state["expanded_retrieval_results"],
            all_documents=state["reranked_documents"],
            sub_question_retrieval_stats=sub_question_retrieval_stats,
        ),
    )
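The aggregation in _calculate_sub_question_retrieval_stats reduces to: pool each chunk's scores across all expanded queries, average per chunk, then average within the verified and rejected groups. A toy worked example of that arithmetic, with plain dicts rather than the real models:

from collections import defaultdict
from statistics import mean

# Each chunk may be hit by several expanded queries, so its scores are pooled
# first, averaged per chunk, then averaged across the verified/rejected groups.
chunk_scores = {"d1_0": [0.9, 0.8], "d2_3": [0.4], "d3_1": [0.7]}
verified = {"d1_0", "d3_1"}

groups: dict[str, list[float]] = defaultdict(list)
for chunk_id, scores in chunk_scores.items():
    key = "verified" if chunk_id in verified else "rejected"
    groups[key].append(mean(scores))

print({k: mean(v) for k, v in groups.items()})
# -> {'verified': 0.775, 'rejected': 0.4}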
33  backend/onyx/agent_search/expanded_retrieval/nodes/verification_kickoff.py  Normal file
@@ -0,0 +1,33 @@
from typing import Literal

from langgraph.types import Command
from langgraph.types import Send

from onyx.agent_search.core_state import in_subgraph_extract_core_fields
from onyx.agent_search.expanded_retrieval.nodes.doc_verification import (
    DocVerificationInput,
)
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalState


def verification_kickoff(
    state: ExpandedRetrievalState,
) -> Command[Literal["doc_verification"]]:
    documents = state["retrieved_documents"]
    verification_question = state.get(
        "question", state["subgraph_search_request"].query
    )
    return Command(
        update={},
        goto=[
            Send(
                node="doc_verification",
                arg=DocVerificationInput(
                    doc_to_verify=doc,
                    question=verification_question,
                    **in_subgraph_extract_core_fields(state),
                ),
            )
            for doc in documents
        ],
    )
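verification_kickoff shows the second fan-out mechanism used in this diff: instead of a conditional edge, a regular node returns a Command whose goto is a list of Sends, so one doc_verification run is spawned per retrieved document. A minimal sketch of that pattern with toy state (not the onyx types):

from operator import add
from typing import Annotated, Literal, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Command, Send


class State(TypedDict):
    items: list[str]
    checked: Annotated[list[str], add]


def kickoff(state: State) -> Command[Literal["check"]]:
    # A plain node (not a conditional edge) decides the routing: one Send
    # per item, each spawning its own "check" invocation in parallel.
    return Command(
        update={},
        goto=[Send("check", {"items": [i], "checked": []}) for i in state["items"]],
    )


def check(state: State) -> dict:
    return {"checked": [state["items"][0].upper()]}


builder = StateGraph(State)
builder.add_node("kickoff", kickoff)
builder.add_node("check", check)
builder.add_edge(START, "kickoff")
builder.add_edge("check", END)

print(builder.compile().invoke({"items": ["a", "b"], "checked": []}))
# -> {'items': ['a', 'b'], 'checked': ['A', 'B']}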
76  backend/onyx/agent_search/expanded_retrieval/states.py  Normal file
@@ -0,0 +1,76 @@
from operator import add
from typing import Annotated
from typing import TypedDict

from onyx.agent_search.core_state import SubgraphCoreState
from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult
from onyx.agent_search.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
from onyx.context.search.models import InferenceSection


### States ###

## Graph Input State


class ExpandedRetrievalInput(SubgraphCoreState):
    question: str
    dummy: str
    base_search: bool = False


## Update/Return States


class QueryExpansionUpdate(TypedDict):
    expanded_queries: list[str]


class DocVerificationUpdate(TypedDict):
    verified_documents: Annotated[list[InferenceSection], dedup_inference_sections]


class DocRetrievalUpdate(TypedDict):
    expanded_retrieval_results: Annotated[list[QueryResult], add]
    retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]


class DocRerankingUpdate(TypedDict):
    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
    sub_question_retrieval_stats: RetrievalFitStats | None


## Graph Output State


class ExpandedRetrievalOutput(TypedDict):
    expanded_retrieval_result: ExpandedRetrievalResult
    base_expanded_retrieval_result: ExpandedRetrievalResult


## Graph State


class ExpandedRetrievalState(
    # This includes the core state
    ExpandedRetrievalInput,
    QueryExpansionUpdate,
    DocRetrievalUpdate,
    DocVerificationUpdate,
    DocRerankingUpdate,
    ExpandedRetrievalOutput,
):
    pass


## Conditional Input States


class DocVerificationInput(ExpandedRetrievalInput):
    doc_to_verify: InferenceSection


class RetrievalInput(ExpandedRetrievalInput):
    query_to_retrieve: str
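A recurring convention across these states.py files: the full graph state is composed by multiple inheritance over the input TypedDict and every node's update TypedDict, so each node can type its return as just the slice of keys it owns. In miniature:

from typing import TypedDict


class InputState(TypedDict):
    question: str


class StepUpdate(TypedDict):
    answer: str


# The graph state is the union of the input keys and each node's update keys.
class GraphState(InputState, StepUpdate):
    pass


state: GraphState = {"question": "q", "answer": "a"}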
87  backend/onyx/agent_search/main/edges.py  Normal file
@@ -0,0 +1,87 @@
from collections.abc import Hashable

from langgraph.types import Send

from onyx.agent_search.answer_question.states import AnswerQuestionInput
from onyx.agent_search.answer_question.states import AnswerQuestionOutput
from onyx.agent_search.core_state import extract_core_fields_for_subgraph
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
from onyx.agent_search.main.states import MainInput
from onyx.agent_search.main.states import MainState


def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]:
    if len(state["initial_decomp_questions"]) > 0:
        return [
            Send(
                "answer_query",
                AnswerQuestionInput(
                    **extract_core_fields_for_subgraph(state),
                    question=question,
                ),
            )
            for question in state["initial_decomp_questions"]
        ]

    else:
        return [
            Send(
                "ingest_answers",
                AnswerQuestionOutput(),
            )
        ]


def send_to_initial_retrieval(state: MainInput) -> list[Send | Hashable]:
    print("sending to initial retrieval via edge")
    return [
        Send(
            "initial_retrieval",
            ExpandedRetrievalInput(
                question=state["search_request"].query,
                **extract_core_fields_for_subgraph(state),
            ),
        )
    ]


# def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
#     # Routes re-written queries to the (parallel) retrieval steps
#     # Notice the 'Send()' API that takes care of the parallelization
#     return [
#         Send(
#             "sub_answers_graph",
#             ResearchQAState(
#                 sub_question=sub_question["sub_question_str"],
#                 sub_question_nr=sub_question["sub_question_nr"],
#                 graph_start_time=state["graph_start_time"],
#                 primary_llm=state["primary_llm"],
#                 fast_llm=state["fast_llm"],
#             ),
#         )
#         for sub_question in state["sub_questions"]
#     ]


# def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
#     print("---GO TO DEEP ANSWER OR END---")

#     base_answer = state["base_answer"]

#     question = state["original_question"]

#     BASE_CHECK_MESSAGE = [
#         HumanMessage(
#             content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
#         )
#     ]

#     model = state["fast_llm"]
#     response = model.invoke(BASE_CHECK_MESSAGE)

#     print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")

#     if response.pretty_repr() == "no":
#         return "decompose"
#     else:
#         return "end"
459
backend/onyx/agent_search/main/graph_builder.py
Normal file
459
backend/onyx/agent_search/main/graph_builder.py
Normal file
@@ -0,0 +1,459 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from onyx.agent_search.answer_question.graph_builder import answer_query_graph_builder
from onyx.agent_search.base_raw_search.graph_builder import (
    base_raw_search_graph_builder,
)
from onyx.agent_search.main.edges import parallelize_decompozed_answer_queries
from onyx.agent_search.main.nodes.base_decomp import main_decomp_base
from onyx.agent_search.main.nodes.generate_initial_answer import (
    generate_initial_answer,
)
from onyx.agent_search.main.nodes.ingest_answers import ingest_answers
from onyx.agent_search.main.nodes.ingest_initial_retrieval import (
    ingest_initial_retrieval,
)
from onyx.agent_search.main.states import MainInput
from onyx.agent_search.main.states import MainState


test_mode = False


def main_graph_builder(test_mode: bool = False) -> StateGraph:
    graph = StateGraph(
        state_schema=MainState,
        input=MainInput,
    )

    graph_component = "both"
    # graph_component = "right"
    # graph_component = "left"

    if graph_component == "left":
        ### Add nodes ###

        graph.add_node(
            node="base_decomp",
            action=main_decomp_base,
        )
        answer_query_subgraph = answer_query_graph_builder().compile()
        graph.add_node(
            node="answer_query",
            action=answer_query_subgraph,
        )

        # graph.add_node(
        #     node="prep_for_initial_retrieval",
        #     action=prep_for_initial_retrieval,
        # )

        # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile()
        # graph.add_node(
        #     node="initial_retrieval",
        #     action=expanded_retrieval_subgraph,
        # )

        # base_raw_search_subgraph = base_raw_search_graph_builder().compile()
        # graph.add_node(
        #     node="base_raw_search_data",
        #     action=base_raw_search_subgraph,
        # )
        # graph.add_node(
        #     node="ingest_initial_retrieval",
        #     action=ingest_initial_retrieval,
        # )
        graph.add_node(
            node="ingest_answers",
            action=ingest_answers,
        )
        graph.add_node(
            node="generate_initial_answer",
            action=generate_initial_answer,
        )
        # if test_mode:
        #     graph.add_node(
        #         node="generate_initial_base_answer",
        #         action=generate_initial_base_answer,
        #     )

        ### Add edges ###

        # graph.add_conditional_edges(
        #     source=START,
        #     path=send_to_initial_retrieval,
        #     path_map=["initial_retrieval"],
        # )

        # graph.add_edge(
        #     start_key=START,
        #     end_key="prep_for_initial_retrieval",
        # )
        # graph.add_edge(
        #     start_key="prep_for_initial_retrieval",
        #     end_key="initial_retrieval",
        # )
        # graph.add_edge(
        #     start_key="initial_retrieval",
        #     end_key="ingest_initial_retrieval",
        # )

        # graph.add_edge(
        #     start_key=START,
        #     end_key="base_raw_search_data"
        # )

        # # graph.add_edge(
        # #     start_key="base_raw_search_data",
        # #     end_key=END
        # # )
        # graph.add_edge(
        #     start_key="base_raw_search_data",
        #     end_key="ingest_initial_retrieval",
        # )
        # graph.add_edge(
        #     start_key="ingest_initial_retrieval",
        #     end_key=END
        # )
        graph.add_edge(
            start_key=START,
            end_key="base_decomp",
        )
        graph.add_conditional_edges(
            source="base_decomp",
            path=parallelize_decompozed_answer_queries,
            path_map=["answer_query"],
        )
        graph.add_edge(
            start_key="answer_query",
            end_key="ingest_answers",
        )

        graph.add_edge(
            start_key="ingest_answers",
            end_key="generate_initial_answer",
        )

        # graph.add_edge(
        #     start_key=["ingest_answers", "ingest_initial_retrieval"],
        #     end_key="generate_initial_answer",
        # )

        graph.add_edge(
            start_key="generate_initial_answer",
            end_key=END,
        )
        # graph.add_edge(
        #     start_key="ingest_answers",
        #     end_key="generate_initial_answer",
        # )
        # if test_mode:
        #     graph.add_edge(
        #         start_key=["ingest_answers", "ingest_initial_retrieval"],
        #         end_key="generate_initial_base_answer",
        #     )
        #     graph.add_edge(
        #         start_key=["generate_initial_answer", "generate_initial_base_answer"],
        #         end_key=END,
        #     )
        # else:
        #     graph.add_edge(
        #         start_key="generate_initial_answer",
        #         end_key=END,
        #     )

    elif graph_component == "right":
        ### Add nodes ###

        # graph.add_node(
        #     node="base_decomp",
        #     action=main_decomp_base,
        # )
        # answer_query_subgraph = answer_query_graph_builder().compile()
        # graph.add_node(
        #     node="answer_query",
        #     action=answer_query_subgraph,
        # )

        # graph.add_node(
        #     node="prep_for_initial_retrieval",
        #     action=prep_for_initial_retrieval,
        # )

        # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile()
        # graph.add_node(
        #     node="initial_retrieval",
        #     action=expanded_retrieval_subgraph,
        # )

        base_raw_search_subgraph = base_raw_search_graph_builder().compile()
        graph.add_node(
            node="base_raw_search_data",
            action=base_raw_search_subgraph,
        )
        graph.add_node(
            node="ingest_initial_retrieval",
            action=ingest_initial_retrieval,
        )
        # graph.add_node(
        #     node="ingest_answers",
        #     action=ingest_answers,
        # )
        graph.add_node(
            node="generate_initial_answer",
            action=generate_initial_answer,
        )
        # if test_mode:
        #     graph.add_node(
        #         node="generate_initial_base_answer",
        #         action=generate_initial_base_answer,
        #     )

        ### Add edges ###

        # graph.add_conditional_edges(
        #     source=START,
        #     path=send_to_initial_retrieval,
        #     path_map=["initial_retrieval"],
        # )

        # graph.add_edge(
        #     start_key=START,
        #     end_key="prep_for_initial_retrieval",
        # )
        # graph.add_edge(
        #     start_key="prep_for_initial_retrieval",
        #     end_key="initial_retrieval",
        # )
        # graph.add_edge(
        #     start_key="initial_retrieval",
        #     end_key="ingest_initial_retrieval",
        # )

        graph.add_edge(start_key=START, end_key="base_raw_search_data")

        # # graph.add_edge(
        # #     start_key="base_raw_search_data",
        # #     end_key=END
        # # )
        graph.add_edge(
            start_key="base_raw_search_data",
            end_key="ingest_initial_retrieval",
        )
        # graph.add_edge(
        #     start_key="ingest_initial_retrieval",
        #     end_key=END
        # )
        # graph.add_edge(
        #     start_key=START,
        #     end_key="base_decomp",
        # )
        # graph.add_conditional_edges(
        #     source="base_decomp",
        #     path=parallelize_decompozed_answer_queries,
        #     path_map=["answer_query"],
        # )
        # graph.add_edge(
        #     start_key="answer_query",
        #     end_key="ingest_answers",
        # )

        # graph.add_edge(
        #     start_key="ingest_answers",
        #     end_key="generate_initial_answer",
        # )

        graph.add_edge(
            start_key="ingest_initial_retrieval",
            end_key="generate_initial_answer",
        )

        # graph.add_edge(
        #     start_key=["ingest_answers", "ingest_initial_retrieval"],
        #     end_key="generate_initial_answer",
        # )

        graph.add_edge(
            start_key="generate_initial_answer",
            end_key=END,
        )
        # graph.add_edge(
        #     start_key="ingest_answers",
        #     end_key="generate_initial_answer",
        # )
        # if test_mode:
        #     graph.add_edge(
        #         start_key=["ingest_answers", "ingest_initial_retrieval"],
        #         end_key="generate_initial_base_answer",
        #     )
        #     graph.add_edge(
        #         start_key=["generate_initial_answer", "generate_initial_base_answer"],
        #         end_key=END,
        #     )
        # else:
        #     graph.add_edge(
        #         start_key="generate_initial_answer",
        #         end_key=END,
        #     )

    else:
        graph.add_node(
            node="base_decomp",
            action=main_decomp_base,
        )
        answer_query_subgraph = answer_query_graph_builder().compile()
        graph.add_node(
            node="answer_query",
            action=answer_query_subgraph,
        )

        # graph.add_node(
        #     node="prep_for_initial_retrieval",
        #     action=prep_for_initial_retrieval,
        # )

        # expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile()
        # graph.add_node(
        #     node="initial_retrieval",
        #     action=expanded_retrieval_subgraph,
        # )

        base_raw_search_subgraph = base_raw_search_graph_builder().compile()
        graph.add_node(
            node="base_raw_search_data",
            action=base_raw_search_subgraph,
        )
        graph.add_node(
            node="ingest_initial_retrieval",
            action=ingest_initial_retrieval,
        )
        graph.add_node(
            node="ingest_answers",
            action=ingest_answers,
        )
        graph.add_node(
            node="generate_initial_answer",
            action=generate_initial_answer,
        )
        # if test_mode:
        #     graph.add_node(
        #         node="generate_initial_base_answer",
        #         action=generate_initial_base_answer,
        #     )

        ### Add edges ###

        # graph.add_conditional_edges(
        #     source=START,
        #     path=send_to_initial_retrieval,
        #     path_map=["initial_retrieval"],
        # )

        # graph.add_edge(
        #     start_key=START,
        #     end_key="prep_for_initial_retrieval",
        # )
        # graph.add_edge(
        #     start_key="prep_for_initial_retrieval",
        #     end_key="initial_retrieval",
        # )
        # graph.add_edge(
        #     start_key="initial_retrieval",
        #     end_key="ingest_initial_retrieval",
        # )

        graph.add_edge(start_key=START, end_key="base_raw_search_data")

        # # graph.add_edge(
        # #     start_key="base_raw_search_data",
        # #     end_key=END
        # # )
        graph.add_edge(
            start_key="base_raw_search_data",
            end_key="ingest_initial_retrieval",
        )
        # graph.add_edge(
        #     start_key="ingest_initial_retrieval",
        #     end_key=END
        # )
        graph.add_edge(
            start_key=START,
            end_key="base_decomp",
        )
        graph.add_conditional_edges(
            source="base_decomp",
            path=parallelize_decompozed_answer_queries,
            path_map=["answer_query"],
        )
        graph.add_edge(
            start_key="answer_query",
            end_key="ingest_answers",
        )

        # graph.add_edge(
        #     start_key="ingest_answers",
        #     end_key="generate_initial_answer",
        # )

        graph.add_edge(
            start_key=["ingest_answers", "ingest_initial_retrieval"],
            end_key="generate_initial_answer",
        )

        graph.add_edge(
            start_key="generate_initial_answer",
            end_key=END,
        )
        # graph.add_edge(
        #     start_key="ingest_answers",
        #     end_key="generate_initial_answer",
        # )
        # if test_mode:
        #     graph.add_edge(
        #         start_key=["ingest_answers", "ingest_initial_retrieval"],
        #         end_key="generate_initial_base_answer",
        #     )
        #     graph.add_edge(
        #         start_key=["generate_initial_answer", "generate_initial_base_answer"],
        #         end_key=END,
        #     )
        # else:
        #     graph.add_edge(
        #         start_key="generate_initial_answer",
        #         end_key=END,
        #     )

    return graph


if __name__ == "__main__":
    pass

    from onyx.db.engine import get_session_context_manager
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = main_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()

    with get_session_context_manager() as db_session:
        search_request = SearchRequest(query="Who created Excel?")

        inputs = MainInput(
            search_request=search_request,
            primary_llm=primary_llm,
            fast_llm=fast_llm,
            db_session=db_session,
        )

        for thing in compiled_graph.stream(
            input=inputs,
            # stream_mode="debug",
            # debug=True,
            subgraphs=True,
        ):
            # print(thing)
            print()
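A quick way to sanity-check which of the three wirings ("left", "right", "both") is actually active is to render the compiled graph. Recent LangGraph versions expose a drawable graph on the compiled object (a sketch, assuming such a version):

    compiled_graph = main_graph_builder().compile()
    # Prints a Mermaid diagram of the wired nodes and edges
    print(compiled_graph.get_graph().draw_mermaid())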
33
backend/onyx/agent_search/main/nodes/base_decomp.py
Normal file
@@ -0,0 +1,33 @@
from langchain_core.messages import HumanMessage

from onyx.agent_search.main.states import BaseDecompUpdate
from onyx.agent_search.main.states import MainState
from onyx.agent_search.shared_graph_utils.prompts import (
    INITIAL_DECOMPOSITION_PROMPT_QUESTIONS,
)
from onyx.agent_search.shared_graph_utils.utils import clean_and_parse_list_string


def main_decomp_base(state: MainState) -> BaseDecompUpdate:
    question = state["search_request"].query

    msg = [
        HumanMessage(
            content=INITIAL_DECOMPOSITION_PROMPT_QUESTIONS.format(question=question),
        )
    ]

    # Get the rewritten queries in a defined format
    model = state["fast_llm"]
    response = model.invoke(msg)

    content = response.pretty_repr()
    list_of_subquestions = clean_and_parse_list_string(content)

    decomp_list: list[str] = [
        sub_question["sub_question"].strip() for sub_question in list_of_subquestions
    ]

    return BaseDecompUpdate(
        initial_decomp_questions=decomp_list,
    )
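For reference, INITIAL_DECOMPOSITION_PROMPT_QUESTIONS (defined in shared_graph_utils/prompts.py further below) asks the model for a JSON list of objects, so clean_and_parse_list_string is expected to yield dictionaries shaped like this (illustrative values only):

    list_of_subquestions = [
        {"sub_question": "What are sales for company A?"},
        {"sub_question": "What are sales for company B?"},
    ]
    # which main_decomp_base flattens into:
    decomp_list = [q["sub_question"].strip() for q in list_of_subquestions]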
@@ -0,0 +1,33 @@
from langchain_core.messages import HumanMessage

from onyx.agent_search.main.states import InitialAnswerBASEUpdate
from onyx.agent_search.main.states import MainState
from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_BASE_PROMPT
from onyx.agent_search.shared_graph_utils.utils import format_docs


def generate_initial_base_answer(state: MainState) -> InitialAnswerBASEUpdate:
    print("---GENERATE INITIAL BASE ANSWER---")

    question = state["search_request"].query
    original_question_docs = state["all_original_question_documents"]

    msg = [
        HumanMessage(
            content=INITIAL_RAG_BASE_PROMPT.format(
                question=question,
                context=format_docs(original_question_docs),
            )
        )
    ]

    # Grader
    model = state["fast_llm"]
    response = model.invoke(msg)
    answer = response.pretty_repr()

    print()
    print(
        f"\n\n---INITIAL BASE ANSWER START---\n\nBase: {answer}\n\n ---INITIAL BASE ANSWER END---\n\n"
    )
    return InitialAnswerBASEUpdate(initial_base_answer=answer)
180
backend/onyx/agent_search/main/nodes/generate_initial_answer.py
Normal file
@@ -0,0 +1,180 @@
from langchain_core.messages import HumanMessage

from onyx.agent_search.answer_question.states import QuestionAnswerResults
from onyx.agent_search.main.states import InitialAnswerUpdate
from onyx.agent_search.main.states import MainState
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
from onyx.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
from onyx.agent_search.shared_graph_utils.prompts import (
    INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS,
)
from onyx.agent_search.shared_graph_utils.utils import format_docs


def _calculate_initial_agent_stats(
    decomp_answer_results: list[QuestionAnswerResults],
    original_question_stats: AgentChunkStats,
) -> InitialAgentResultStats:
    initial_agent_result_stats: InitialAgentResultStats = InitialAgentResultStats(
        sub_questions={},
        original_question={},
        agent_effectiveness={},
    )

    orig_verified = original_question_stats.verified_count
    orig_support_score = original_question_stats.verified_avg_scores

    verified_document_chunk_ids = []
    support_scores = 0.0

    for decomp_answer_result in decomp_answer_results:
        verified_document_chunk_ids += (
            decomp_answer_result.sub_question_retrieval_stats.verified_doc_chunk_ids
        )
        if (
            decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
            is not None
        ):
            support_scores += (
                decomp_answer_result.sub_question_retrieval_stats.verified_avg_scores
            )

    verified_document_chunk_ids = list(set(verified_document_chunk_ids))

    # Calculate sub-question stats
    if (
        verified_document_chunk_ids
        and len(verified_document_chunk_ids) > 0
        and support_scores is not None
    ):
        sub_question_stats: dict[str, float | int | None] = {
            "num_verified_documents": len(verified_document_chunk_ids),
            "verified_avg_score": float(support_scores / len(decomp_answer_results)),
        }
    else:
        sub_question_stats = {"num_verified_documents": 0, "verified_avg_score": None}

    initial_agent_result_stats.sub_questions.update(sub_question_stats)

    # Get original question stats
    initial_agent_result_stats.original_question.update(
        {
            "num_verified_documents": original_question_stats.verified_count,
            "verified_avg_score": original_question_stats.verified_avg_scores,
        }
    )

    # Calculate chunk utilization ratio
    sub_verified = initial_agent_result_stats.sub_questions["num_verified_documents"]

    chunk_ratio: float | None = None
    if sub_verified is not None and orig_verified is not None and orig_verified > 0:
        chunk_ratio = (float(sub_verified) / orig_verified) if sub_verified > 0 else 0.0
    elif sub_verified is not None and sub_verified > 0:
        chunk_ratio = 10.0

    initial_agent_result_stats.agent_effectiveness["utilized_chunk_ratio"] = chunk_ratio

    if (
        orig_support_score is None
        and initial_agent_result_stats.sub_questions["verified_avg_score"] is None
    ):
        initial_agent_result_stats.agent_effectiveness["support_ratio"] = None
    elif orig_support_score is None:
        initial_agent_result_stats.agent_effectiveness["support_ratio"] = 10
    elif initial_agent_result_stats.sub_questions["verified_avg_score"] is None:
        initial_agent_result_stats.agent_effectiveness["support_ratio"] = 0
    else:
        initial_agent_result_stats.agent_effectiveness["support_ratio"] = (
            initial_agent_result_stats.sub_questions["verified_avg_score"]
            / orig_support_score
        )

    return initial_agent_result_stats


def generate_initial_answer(state: MainState) -> InitialAnswerUpdate:
    print("---GENERATE INITIAL---")

    question = state["search_request"].query
    sub_question_docs = state["documents"]
    all_original_question_documents = state["all_original_question_documents"]
    relevant_docs = dedup_inference_sections(
        sub_question_docs, all_original_question_documents
    )

    net_new_original_question_docs = []
    for all_original_question_doc in all_original_question_documents:
        if all_original_question_doc not in sub_question_docs:
            net_new_original_question_docs.append(all_original_question_doc)

    decomp_answer_results = state["decomp_answer_results"]

    good_qa_list: list[str] = []
    decomp_questions = []

    _SUB_QUESTION_ANSWER_TEMPLATE = """
Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
"""
    for decomp_answer_result in decomp_answer_results:
        decomp_questions.append(decomp_answer_result.question)
        if (
            decomp_answer_result.quality.lower().startswith("yes")
            and len(decomp_answer_result.answer) > 0
            and decomp_answer_result.answer != "I don't know"
        ):
            good_qa_list.append(
                _SUB_QUESTION_ANSWER_TEMPLATE.format(
                    sub_question=decomp_answer_result.question,
                    sub_answer=decomp_answer_result.answer,
                )
            )

    sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)

    if len(good_qa_list) > 0:
        msg = [
            HumanMessage(
                content=INITIAL_RAG_PROMPT.format(
                    question=question,
                    answered_sub_questions=sub_question_answer_str,
                    relevant_docs=format_docs(relevant_docs),
                )
            )
        ]
    else:
        msg = [
            HumanMessage(
                content=INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS.format(
                    question=question,
                    relevant_docs=format_docs(relevant_docs),
                )
            )
        ]

    # Grader
    model = state["fast_llm"]
    response = model.invoke(msg)
    answer = response.pretty_repr()

    initial_agent_stats = _calculate_initial_agent_stats(
        state["decomp_answer_results"], state["original_question_retrieval_stats"]
    )

    print(f"\n\n---INITIAL AGENT ANSWER START---\n\n Answer:\n Agent: {answer}")

    print(f"\n\nSub-Questions:\n\n{sub_question_answer_str}\n\nStats:\n\n")

    if initial_agent_stats:
        print(initial_agent_stats.original_question)
        print(initial_agent_stats.sub_questions)
        print(initial_agent_stats.agent_effectiveness)
    print("\n\n ---INITIAL AGENT ANSWER END---\n\n")

    return InitialAnswerUpdate(
        initial_answer=answer,
        initial_agent_stats=initial_agent_stats,
        generated_sub_questions=decomp_questions,
    )
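To make the two effectiveness ratios concrete (illustrative numbers only): with two answered sub-questions whose verified average scores sum to support_scores = 1.3 (0.6 + 0.7), 8 unique verified chunk ids across them, and an original-question retrieval with orig_verified = 10 and orig_support_score = 0.5, _calculate_initial_agent_stats reports:

    sub_questions = {"num_verified_documents": 8, "verified_avg_score": 1.3 / 2}  # 0.65
    utilized_chunk_ratio = 8 / 10                                                 # 0.8
    support_ratio = 0.65 / 0.5                                                    # 1.3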
16
backend/onyx/agent_search/main/nodes/ingest_answers.py
Normal file
@@ -0,0 +1,16 @@
from onyx.agent_search.answer_question.states import AnswerQuestionOutput
from onyx.agent_search.main.states import DecompAnswersUpdate
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections


def ingest_answers(state: AnswerQuestionOutput) -> DecompAnswersUpdate:
    documents = []
    answer_results = state.get("answer_results", [])
    for answer_result in answer_results:
        documents.extend(answer_result.documents)
    return DecompAnswersUpdate(
        # Deduping is done by the documents operator for the main graph
        # so we might not need to dedup here
        documents=dedup_inference_sections(documents, []),
        decomp_answer_results=answer_results,
    )
@@ -0,0 +1,23 @@
from onyx.agent_search.base_raw_search.states import BaseRawSearchOutput
from onyx.agent_search.main.states import ExpandedRetrievalUpdate
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats


def ingest_initial_retrieval(state: BaseRawSearchOutput) -> ExpandedRetrievalUpdate:
    sub_question_retrieval_stats = state[
        "base_expanded_retrieval_result"
    ].sub_question_retrieval_stats
    if sub_question_retrieval_stats is None:
        sub_question_retrieval_stats = AgentChunkStats()
    else:
        sub_question_retrieval_stats = sub_question_retrieval_stats
    return ExpandedRetrievalUpdate(
        original_question_retrieval_results=state[
            "base_expanded_retrieval_result"
        ].expanded_queries_results,
        all_original_question_documents=state[
            "base_expanded_retrieval_result"
        ].all_documents,
        original_question_retrieval_stats=sub_question_retrieval_stats,
    )
@@ -0,0 +1,12 @@
from onyx.agent_search.core_state import extract_core_fields_for_subgraph
from onyx.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
from onyx.agent_search.main.states import MainState


def prep_for_initial_retrieval(state: MainState) -> ExpandedRetrievalInput:
    print("prepping")
    return ExpandedRetrievalInput(
        question=state["search_request"].query,
        dummy="0",
        **extract_core_fields_for_subgraph(state)
    )
77
backend/onyx/agent_search/main/states.py
Normal file
@@ -0,0 +1,77 @@
from operator import add
from typing import Annotated
from typing import TypedDict

from onyx.agent_search.answer_question.states import QuestionAnswerResults
from onyx.agent_search.core_state import CoreState
from onyx.agent_search.expanded_retrieval.models import ExpandedRetrievalResult
from onyx.agent_search.expanded_retrieval.models import QueryResult
from onyx.agent_search.shared_graph_utils.models import AgentChunkStats
from onyx.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agent_search.shared_graph_utils.operators import dedup_inference_sections
from onyx.context.search.models import InferenceSection


### States ###

## Update States


class BaseDecompUpdate(TypedDict):
    initial_decomp_questions: list[str]


class InitialAnswerBASEUpdate(TypedDict):
    initial_base_answer: str


class InitialAnswerUpdate(TypedDict):
    initial_answer: str
    initial_agent_stats: InitialAgentResultStats
    generated_sub_questions: list[str]


class DecompAnswersUpdate(TypedDict):
    documents: Annotated[list[InferenceSection], dedup_inference_sections]
    decomp_answer_results: Annotated[list[QuestionAnswerResults], add]


class ExpandedRetrievalUpdate(TypedDict):
    all_original_question_documents: Annotated[
        list[InferenceSection], dedup_inference_sections
    ]
    original_question_retrieval_results: list[QueryResult]
    original_question_retrieval_stats: AgentChunkStats


## Graph Input State


class MainInput(CoreState):
    pass


## Graph State


class MainState(
    # This includes the core state
    MainInput,
    BaseDecompUpdate,
    InitialAnswerUpdate,
    InitialAnswerBASEUpdate,
    DecompAnswersUpdate,
    ExpandedRetrievalUpdate,
):
    # expanded_retrieval_result: Annotated[list[ExpandedRetrievalResult], add]
    base_raw_search_result: Annotated[list[ExpandedRetrievalResult], add]


## Graph Output State


class MainOutput(TypedDict):
    initial_answer: str
    initial_base_answer: str
    initial_agent_stats: dict
    generated_sub_questions: list[str]
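The Annotated reducers above are what let the parallel answer_query branches merge safely: when two branches write the same key, LangGraph combines the current and incoming values with the annotated function instead of overwriting. Roughly (a sketch of the merge semantics):

    # documents: merged via the dedup operator
    merged_docs = dedup_inference_sections(current_docs, update_docs)
    # decomp_answer_results: merged via operator.add (list concatenation)
    merged_results = current_results + update_results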
120
backend/onyx/agent_search/run_graph.py
Normal file
@@ -0,0 +1,120 @@
import asyncio
from collections.abc import AsyncIterable
from collections.abc import Iterable

from langchain_core.runnables.schema import StreamEvent
from langgraph.graph.state import CompiledStateGraph

from onyx.agent_search.main.graph_builder import main_graph_builder
from onyx.agent_search.main.states import MainInput
from onyx.chat.answer import AnswerStream
from onyx.chat.models import AnswerQuestionPossibleReturn
from onyx.chat.models import OnyxAnswerPiece
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.llm.interfaces import LLM
from onyx.tools.models import ToolResponse
from onyx.tools.tool_runner import ToolCallKickoff


def _parse_agent_event(
    event: StreamEvent,
) -> AnswerQuestionPossibleReturn | ToolCallKickoff | ToolResponse | None:
    """
    Parse the event into a typed object.
    Return None if we are not interested in the event.
    """
    # if event["name"] == "LangGraph":
    #     return None

    event_type = event["event"]
    langgraph_node = event["metadata"].get("langgraph_node", "_graph_")
    if "input" in event["data"] and isinstance(event["data"]["input"], str):
        input_data = f'\nINPUT: {langgraph_node} -- {str(event["data"]["input"])}'
    else:
        input_data = ""
    if "output" in event["data"] and isinstance(event["data"]["output"], str):
        output_data = f'\nOUTPUT: {langgraph_node} -- {str(event["data"]["output"])}'
    else:
        output_data = ""
    if len(input_data) > 0 or len(output_data) > 0:
        return input_data + output_data

    event_type = event["event"]
    if event_type == "tool_call_kickoff":
        return ToolCallKickoff(**event["data"])
    elif event_type == "tool_response":
        return ToolResponse(**event["data"])
    elif event_type == "on_chat_model_stream":
        return OnyxAnswerPiece(answer_piece=event["data"]["chunk"].content)
    return None


def _manage_async_event_streaming(
    compiled_graph: CompiledStateGraph,
    graph_input: MainInput,
) -> Iterable[StreamEvent]:
    async def _run_async_event_stream() -> AsyncIterable[StreamEvent]:
        async for event in compiled_graph.astream_events(
            input=graph_input,
            # indicating v2 here deserves further scrutiny
            version="v2",
        ):
            yield event

    # This might be able to be simplified
    def _yield_async_to_sync() -> Iterable[StreamEvent]:
        loop = asyncio.new_event_loop()
        try:
            # Get the async generator
            async_gen = _run_async_event_stream()
            # Convert to AsyncIterator
            async_iter = async_gen.__aiter__()
            while True:
                try:
                    # Create a coroutine by calling anext with the async iterator
                    next_coro = anext(async_iter)
                    # Run the coroutine to get the next event
                    event = loop.run_until_complete(next_coro)
                    yield event
                except StopAsyncIteration:
                    break
        finally:
            loop.close()

    return _yield_async_to_sync()


def run_graph(
    compiled_graph: CompiledStateGraph,
    search_request: SearchRequest,
    primary_llm: LLM,
    fast_llm: LLM,
) -> AnswerStream:
    with get_session_context_manager() as db_session:
        input = MainInput(
            search_request=search_request,
            primary_llm=primary_llm,
            fast_llm=fast_llm,
            db_session=db_session,
        )
        for event in _manage_async_event_streaming(
            compiled_graph=compiled_graph, graph_input=input
        ):
            if parsed_object := _parse_agent_event(event):
                yield parsed_object


if __name__ == "__main__":
    from onyx.llm.factory import get_default_llms
    from onyx.context.search.models import SearchRequest

    graph = main_graph_builder()
    compiled_graph = graph.compile()
    primary_llm, fast_llm = get_default_llms()
    search_request = SearchRequest(
        query="what can you do with onyx or danswer?",
    )
    for output in run_graph(compiled_graph, search_request, primary_llm, fast_llm):
        print("a")
        # print(output)
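The hand-rolled loop in _yield_async_to_sync could likely be condensed on Python 3.11+ with asyncio.Runner, which owns loop creation and cleanup; a minimal sketch under that version assumption (the wrapper coroutine is needed because Runner.run only accepts true coroutines):

    def _yield_async_to_sync() -> Iterable[StreamEvent]:
        async def _next(ait):  # wrap anext() so Runner.run sees a coroutine
            return await anext(ait)

        with asyncio.Runner() as runner:
            async_iter = _run_async_event_stream().__aiter__()
            while True:
                try:
                    yield runner.run(_next(async_iter))
                except StopAsyncIteration:
                    break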
98
backend/onyx/agent_search/shared_graph_utils/calculations.py
Normal file
@@ -0,0 +1,98 @@
import numpy as np

from onyx.agent_search.shared_graph_utils.models import RetrievalFitScoreMetrics
from onyx.agent_search.shared_graph_utils.models import RetrievalFitStats
from onyx.chat.models import SectionRelevancePiece
from onyx.context.search.models import InferenceSection
from onyx.utils.logger import setup_logger

logger = setup_logger()


def unique_chunk_id(doc: InferenceSection) -> str:
    return f"{doc.center_chunk.document_id}_{doc.center_chunk.chunk_id}"


def calculate_rank_shift(list1: list, list2: list, top_n: int = 20) -> float:
    shift = 0
    for rank_first, doc_id in enumerate(list1[:top_n], 1):
        try:
            rank_second = list2.index(doc_id) + 1
        except ValueError:
            rank_second = len(list2)  # Document not found in second list

        shift += np.abs(rank_first - rank_second) / np.log(1 + rank_first * rank_second)

    return shift / top_n


def get_fit_scores(
    pre_reranked_results: list[InferenceSection],
    post_reranked_results: list[InferenceSection] | list[SectionRelevancePiece],
) -> RetrievalFitStats | None:
    """
    Calculate retrieval metrics for search purposes
    """

    if len(pre_reranked_results) == 0 or len(post_reranked_results) == 0:
        return None

    ranked_sections = {
        "initial": pre_reranked_results,
        "reranked": post_reranked_results,
    }

    fit_eval: RetrievalFitStats = RetrievalFitStats(
        fit_score_lift=0,
        rerank_effect=0,
        fit_scores={
            "initial": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
            "reranked": RetrievalFitScoreMetrics(scores={}, chunk_ids=[]),
        },
    )

    for rank_type, docs in ranked_sections.items():
        print(f"rank_type: {rank_type}")

        for i in [1, 5, 10]:
            fit_eval.fit_scores[rank_type].scores[str(i)] = (
                sum(
                    [
                        float(doc.center_chunk.score)
                        for doc in docs[:i]
                        if type(doc) == InferenceSection
                        and doc.center_chunk.score is not None
                    ]
                )
                / i
            )

        fit_eval.fit_scores[rank_type].scores["fit_score"] = (
            1
            / 3
            * (
                fit_eval.fit_scores[rank_type].scores["1"]
                + fit_eval.fit_scores[rank_type].scores["5"]
                + fit_eval.fit_scores[rank_type].scores["10"]
            )
        )

        fit_eval.fit_scores[rank_type].scores["fit_score"] = fit_eval.fit_scores[
            rank_type
        ].scores["1"]

        fit_eval.fit_scores[rank_type].chunk_ids = [
            unique_chunk_id(doc) for doc in docs if type(doc) == InferenceSection
        ]

    fit_eval.fit_score_lift = (
        fit_eval.fit_scores["reranked"].scores["fit_score"]
        / fit_eval.fit_scores["initial"].scores["fit_score"]
    )

    fit_eval.rerank_effect = calculate_rank_shift(
        fit_eval.fit_scores["initial"].chunk_ids,
        fit_eval.fit_scores["reranked"].chunk_ids,
    )

    return fit_eval
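Two observations on the metrics above. First, "fit_score" is assigned twice in get_fit_scores: the averaged top-1/5/10 value is immediately overwritten with the plain top-1 score, so only the second assignment takes effect. Second, a worked example of calculate_rank_shift: for list1 = ["a", "b"] and list2 = ["b", "a"] with top_n=2, each document moves by exactly one rank, giving

    shift = |1 - 2| / ln(1 + 1 * 2) + |2 - 1| / ln(1 + 2 * 1)  # ≈ 0.910 + 0.910
    result = shift / 2                                          # ≈ 0.910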
47
backend/onyx/agent_search/shared_graph_utils/models.py
Normal file
@@ -0,0 +1,47 @@
from typing import Literal

from pydantic import BaseModel


# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
    rewritten_queries: list[str]


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


class BinaryDecisionWithReasoning(BaseModel):
    reasoning: str
    decision: Literal["yes", "no"]


class RetrievalFitScoreMetrics(BaseModel):
    scores: dict[str, float]
    chunk_ids: list[str]


class RetrievalFitStats(BaseModel):
    fit_score_lift: float
    rerank_effect: float
    fit_scores: dict[str, RetrievalFitScoreMetrics]


class AgentChunkScores(BaseModel):
    scores: dict[str, dict[str, list[int | float]]]


class AgentChunkStats(BaseModel):
    verified_count: int | None = None
    verified_avg_scores: float | None = None
    rejected_count: int | None = None
    rejected_avg_scores: float | None = None
    verified_doc_chunk_ids: list[str] = []
    dismissed_doc_chunk_ids: list[str] = []


class InitialAgentResultStats(BaseModel):
    sub_questions: dict[str, float | int | None]
    original_question: dict[str, float | int | None]
    agent_effectiveness: dict[str, float | int | None]
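Note that ingest_initial_retrieval (above) instantiates AgentChunkStats() with no arguments, which is why the fields carry defaults here; under pydantic, optional-typed fields without a default would still be required and the bare call would fail validation. A quick check of that fallback path (sketch):

    stats = AgentChunkStats()
    assert stats.verified_count is None and stats.verified_doc_chunk_ids == []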
@@ -0,0 +1,9 @@
from onyx.chat.prune_and_merge import _merge_sections
from onyx.context.search.models import InferenceSection


def dedup_inference_sections(
    list1: list[InferenceSection], list2: list[InferenceSection]
) -> list[InferenceSection]:
    deduped = _merge_sections(list1 + list2)
    return deduped
543
backend/onyx/agent_search/shared_graph_utils/prompts.py
Normal file
@@ -0,0 +1,543 @@
|
||||
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
|
||||
Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a
|
||||
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
|
||||
enabling the system to search more broadly.
|
||||
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the querytext): """
|
||||
|
||||
REWRITE_PROMPT_MULTI = """ \n
|
||||
Please create a list of 2-3 sample documents that could answer an original question. Each document
|
||||
should be about as long as the original question. \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
|
||||
|
||||
BASE_RAG_PROMPT = """ \n
|
||||
You are an assistant for question-answering tasks. Use the context provided below - and only the
|
||||
provided context - to answer the given question. (Note that the answer is in service of anserwing a broader
|
||||
question, given below as 'motivation'.)
|
||||
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
Make sure that you keep all relevant information, specifically as it concerns to the ultimate goal.
|
||||
(But keep other details as well.)
|
||||
|
||||
\nContext:\n {context} \n
|
||||
|
||||
Motivation:\n {original_question} \n\n
|
||||
\n\n
|
||||
And here is the question I want you to answer based on the context above (with the motivation in mind):
|
||||
\n--\n {question} \n--\n
|
||||
"""
|
||||
|
||||
SUB_CHECK_PROMPT = """
|
||||
Your task is to see whether a given answer addresses a given question.
|
||||
Please do not use any internal knowledge you may have - just focus on whether the answer
|
||||
as given seems to largely address the question as given, or at least addresses part of the question.
|
||||
Here is the question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the suggested answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
Does the suggested answer address the question? Please answer with yes or no:"""
|
||||
|
||||
|
||||
BASE_CHECK_PROMPT = """ \n
|
||||
Please check whether 1) the suggested answer seems to fully address the original question AND 2)the
|
||||
original question requests a simple, factual answer, and there are no ambiguities, judgements,
|
||||
aggregations, or any other complications that may require extra context. (I.e., if the question is
|
||||
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
|
||||
|
||||
Please only answer with 'yes' or 'no' \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the proposed answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
Please answer with yes or no:"""
|
||||
|
||||
VERIFIER_PROMPT = """
|
||||
You are supposed to judge whether a document text contains data or information that is potentially relevant for a question.
|
||||
|
||||
Here is a document text that you can take as a fact:
|
||||
--
|
||||
DOCUMENT INFORMATION:
|
||||
{document_content}
|
||||
--
|
||||
|
||||
Do you think that this information is useful and relevant to answer the following question?
|
||||
(Other documents may supply additional information, so do not worry if the provided information
|
||||
is not enough to answer the question, but it needs to be relevant to the question.)
|
||||
--
|
||||
QUESTION:
|
||||
{question}
|
||||
--
|
||||
|
||||
Please answer with 'yes' or 'no':
|
||||
|
||||
Answer:
|
||||
|
||||
"""
|
||||
|
||||
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
|
||||
If you think it is helpful, please decompose an initial user question into not more
|
||||
than 4 appropriate sub-questions that help to answer the original question.
|
||||
The purpose for this decomposition is to isolate individulal entities
|
||||
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
|
||||
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
|
||||
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
|
||||
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system.
|
||||
|
||||
Importantly, if you think it is not needed or helpful, please just return an empty list. That is ok too.
|
||||
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Please formulate your answer as a list of subquestions:
|
||||
|
||||
Answer:
|
||||
"""
|
||||
|
||||
REWRITE_PROMPT_SINGLE = """ \n
|
||||
Please convert an initial user question into a more appropriate search query for retrievel from a
|
||||
document store. \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Formulate the query: """
|
||||
|
||||
MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
|
||||
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
|
||||
Use three sentences maximum and keep the answer concise.
|
||||
Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
\nQuestion: {question}
|
||||
\nContext: {combined_context} \n
|
||||
|
||||
Answer:"""
|
||||
|
||||
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
|
||||
good enough. Also, some sub-questions had been answered and this information has been used to provide
|
||||
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
|
||||
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
|
||||
you have an idea of how the avaiolable data looks like.
|
||||
|
||||
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
|
||||
considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered
|
||||
4) The sub-questions that were suggested but not answered
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question, but in a way that does
|
||||
not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
resolve ambiguities, or address shortcoming of the initial answer
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
|
||||
generate a list of dictionaries with the following format:
|
||||
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
|
||||
sub-question using as a search phrase for the document store>}}, ...]
|
||||
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
\n ------- \n
|
||||
{answered_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
\n ------- \n
|
||||
{failed_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Again, please find questions that are NOT overlapping too much with the already answered
|
||||
sub-questions or those that already were suggested and failed.
|
||||
In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Please think through it step by step and then generate the list of json dictionaries with the following
|
||||
format:
|
||||
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"explanation": <explanation>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
DEEP_DECOMPOSE_PROMPT = """ \n
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
|
||||
good enough. Also, some sub-questions had been answered and this information has been used to provide
|
||||
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
|
||||
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
|
||||
you have an idea of how the avaiolable data looks like.
|
||||
|
||||
Your role is to generate 4-6 new sub-questions that would help to answer the initial question,
|
||||
considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered
|
||||
4) The sub-questions that were suggested but not answered
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question, but in a way that does
|
||||
not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
resolve ambiguities, or address shortcoming of the initial answer
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please also provide a search term that can be used to retrieve relevant
|
||||
documents from a document store.
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
\n ------- \n
|
||||
{answered_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
\n ------- \n
|
||||
{failed_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Again, please find questions that are NOT overlapping too much with the already answered
|
||||
sub-questions or those that already were suggested and failed.
|
||||
In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Generate the list of json dictionaries with the following format:
|
||||
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
DECOMPOSE_PROMPT = """ \n
|
||||
For an initial user question, please generate at 5-10 individual sub-questions whose answers would help
|
||||
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question.
|
||||
|
||||
In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
|
||||
document store, expressed as entities, relationships and terms. You can also think about the types
|
||||
mentioned in brackets
|
||||
|
||||
Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
and or resolve ambiguities
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
|
||||
generate a list of dictionaries with the following format:
|
||||
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
|
||||
sub-question using as a search phrase for the document store>}}, ...]
|
||||
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Don't be too specific unless the original question is specific.
|
||||
Please think through it step by step and then generate the list of json dictionaries with the following
|
||||
format:
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"explanation": <explanation>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
#### Consolidations
|
||||
COMBINED_CONTEXT = """-------
|
||||
Below you will find useful information to answer the original question. First, you see a number of
|
||||
sub-questions with their answers. This information should be considered to be more focussed and
|
||||
somewhat more specific to the original question as it tries to contextualized facts.
|
||||
After that will see the documents that were considered to be relevant to answer the original question.
|
||||
|
||||
Here are the sub-questions and their answers:
|
||||
\n\n {deep_answer_context} \n\n
|
||||
\n\n Here are the documents that were considered to be relevant to answer the original question:
|
||||
\n\n {formated_docs} \n\n
|
||||
----------------
|
||||
"""
|
||||
|
||||
SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
|
||||
Below you will find a question that we ultimately want to answer (the original question) and a list of
|
||||
motivations in arbitrary order for generated sub-questions that are supposed to help us answering the
|
||||
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
|
||||
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
|
||||
motivation and 2 is less relevant.)
|
||||
|
||||
Please rank the motivations in order of relevance for answering the original question. Also, try to
|
||||
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
|
||||
Ultimately, create a list with the motivation numbers where the number of the most relevant
|
||||
motivations comes first.
|
||||
|
||||
Here is the original question:
|
||||
\n\n {original_question} \n\n
|
||||
\n\n Here is the list of sub-question motivations:
|
||||
\n\n {sub_question_explanations} \n\n
|
||||
----------------
|
||||
|
||||
Please think step by step and then generate the ranked list of motivations.
|
||||
|
||||
Please format your answer as a json object in the following format:
|
||||
{{"reasonning": <explain your reasoning for the ranking>,
|
||||
"ranked_motivations": <ranked list of motivation numbers>}}
|
||||
"""
|
||||
|
||||
|
||||
INITIAL_DECOMPOSITION_PROMPT_QUESTIONS = """
|
||||
If you think it is helpful, please decompose an initial user question into 2 or 4 appropriate sub-questions that help to
|
||||
answer the original question. The purpose for this decomposition is to
|
||||
1) isolate individual entities (i.e., 'compare sales of company A and company B' -> ['what are sales for company A',
|
||||
'what are sales for company B')]
|
||||
2) clarify or disambiguate ambiguous terms (i.e., 'what is our success with company A' -> ['what are our sales with company A',
|
||||
'what is our market share with company A', 'is company A a reference customer for us', etc.])
|
||||
3) if a term or a metric is essentially clear, but it could relate to various components of an entity and you are generally
|
||||
familiar with the entity, then you can decompose the question into sub-questions that are more specific to components
|
||||
(i.e., 'what do we do to improve scalability of product X', 'what do we to to improve scalability of product X',
|
||||
'what do we do to improve stability of product X', ...])
|
||||
|
||||
If you think that a decomposition is not needed or helpful, please just return an empty list. That is ok too.
|
||||
|
||||
Here is the initial question:
|
||||
-------
|
||||
{question}
|
||||
-------
|
||||
Please formulate your answer as a list of json objects with the following format:
|
||||
[{{"sub_question": <sub-question>}}, ...]
|
||||
|
||||
Answer:"""
|
||||
|
||||
INITIAL_DECOMPOSITION_PROMPT = """ \n
|
||||
Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to
|
||||
answer the original question. The purpose for this decomposition is to isolate individulal entities
|
||||
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
|
||||
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
|
||||
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
|
||||
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
|
||||
|
||||
For each sub-question, please also create one search term that can be used to retrieve relevant
|
||||
documents from a document store.
|
||||
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Please formulate your answer as a list of json objects with the following format:
|
||||
|
||||
[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
|
||||
|
||||
Answer:
|
||||
"""
|
||||
|
||||
INITIAL_RAG_BASE_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.

The information provided below consists of a number of documents that were deemed relevant for the question.

IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say "I don't know".
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.

Try to keep your answer concise.

Here is the contextual information from the document store:
\n ------- \n
{context} \n\n\n
\n ------- \n
And here is the question I want you to answer based on the context above:
\n--\n {question} \n--\n
Answer:"""
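A note on the brace conventions in these templates: the doubled braces ({{ and }}) in the JSON-format prompts are escapes for Python's str.format, so {question} is a substitution slot while {{"sub_question": ...}} survives as literal JSON braces. A minimal sketch of how a template is presumably filled (the argument values are illustrative):

    prompt = INITIAL_RAG_BASE_PROMPT.format(
        context="<formatted documents>",
        question="What are sales for company A?",
    )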
INITIAL_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.

The information provided below consists of:
1) a number of answered sub-questions - these are very important(!) and definitely should be
considered when answering the question.
2) a number of documents that were also deemed relevant for the question.

IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is empty or irrelevant, just say "I don't know".
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.

Again, you should be sure that the answer is supported by the information provided!

Try to keep your answer concise. But also highlight substantial uncertainties you may have,
as well as assumptions you made.

Here is the contextual information:
\n-------\n
*Answered Sub-questions (these should really matter!):
{answered_sub_questions}

And here is relevant document information that supports the sub-question answers, or that is relevant for the actual question:\n

{relevant_docs}

\n-------\n
\n
And here is the question I want you to answer based on the information above:
\n--\n
{question}
\n--\n\n
Answer:"""
INITIAL_RAG_PROMPT_NO_SUB_QUESTIONS = """
You are an assistant for question-answering tasks. Use the information provided below
- and only the provided information - to answer the provided question.
The information provided below consists of a number of documents that were deemed relevant for the question.

IMPORTANT RULES:
- If you cannot reliably answer the question solely using the provided information, say that you cannot reliably answer.
You may give some additional facts you learned, but do not try to invent an answer.
- If the information is irrelevant, just say "I don't know".
- If the information is relevant but not fully conclusive, specify that the information is not conclusive and say why.

Again, you should be sure that the answer is supported by the information provided!

Try to keep your answer concise.

Here is the relevant context information:
\n-------\n
{relevant_docs}
\n-------\n

And here is the question I want you to answer based on the context above:
\n--\n
{question}
\n--\n

Answer:"""
ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retrieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.

\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n

Please format your answer as a json object in the following format:

{{"retrieved_entities_relationships": {{
    "entities": [{{
        "entity_name": <assign a name for the entity>,
        "entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
    }}],
    "relationships": [{{
        "name": <assign a name for the relationship>,
        "type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
        "entities": [<related entity name 1>, <related entity name 2>]
    }}],
    "terms": [{{
        "term_name": <assign a name for the term>,
        "term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
        "similar_to": <list terms that are similar to this term>
    }}]
}}
}}
"""
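A sketch of how output from ENTITY_TERM_PROMPT could be post-processed with the helpers from shared_graph_utils/utils.py below (the raw value is a hypothetical, minimal LLM response):

    from onyx.agent_search.shared_graph_utils.utils import (
        clean_and_parse_json_string,
        format_entity_term_extraction,
    )

    # Hypothetical raw LLM output for ENTITY_TERM_PROMPT.
    raw = '{"retrieved_entities_relationships": {"entities": [], "relationships": [], "terms": []}}'

    parsed = clean_and_parse_json_string(raw)
    print(format_entity_term_extraction(parsed["retrieved_entities_relationships"]))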
101
backend/onyx/agent_search/shared_graph_utils/utils.py
Normal file
@@ -0,0 +1,101 @@
import ast
import json
import re
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any

from onyx.context.search.models import InferenceSection


def normalize_whitespace(text: str) -> str:
    """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
    return re.sub(r"\s+", " ", text.strip())


# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
    return "\n\n".join(doc.combined_content for doc in docs)


def clean_and_parse_list_string(json_string: str) -> list[dict]:
    # Remove any prefixes/labels before the actual JSON content
    json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)

    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())

    # Try parsing with json.loads first, fall back to ast.literal_eval
    try:
        return json.loads(cleaned_string)
    except json.JSONDecodeError:
        try:
            return ast.literal_eval(cleaned_string)
        except (ValueError, SyntaxError) as e:
            raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e


def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())

    # Parse the cleaned string into a Python dictionary
    return json.loads(cleaned_string)


def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
    entities = entity_term_extraction_dict["entities"]
    terms = entity_term_extraction_dict["terms"]
    relationships = entity_term_extraction_dict["relationships"]

    entity_strs = ["\nEntities:\n"]
    for entity in entities:
        entity_strs.append(f"{entity['entity_name']} ({entity['entity_type']})")

    relationship_strs = ["\n\nRelationships:\n"]
    for relationship in relationships:
        relationship_strs.append(
            f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
        )

    term_strs = ["\n\nTerms:\n"]
    for term in terms:
        term_strs.append(
            f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
        )

    return "\n".join(entity_strs + relationship_strs + term_strs)


def _format_time_delta(time: timedelta) -> str:
    seconds_from_start = f"{time.seconds:03d}"
    microseconds_from_start = f"{time.microseconds:06d}"
    return f"{seconds_from_start}.{microseconds_from_start}"


def generate_log_message(
    message: str,
    node_start_time: datetime,
    graph_start_time: datetime | None = None,
) -> str:
    current_time = datetime.now()

    if graph_start_time is not None:
        graph_time_str = _format_time_delta(current_time - graph_start_time)
    else:
        graph_time_str = "N/A"

    node_time_str = _format_time_delta(current_time - node_start_time)

    return f"{graph_time_str} ({node_time_str} s): {message}"
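A minimal usage sketch for the timing helpers above (the message and timestamps are illustrative):

    from datetime import datetime

    from onyx.agent_search.shared_graph_utils.utils import generate_log_message

    graph_start = datetime.now()
    node_start = datetime.now()
    # ... node does its work ...
    print(
        generate_log_message(
            "expanded retrieval finished",
            node_start_time=node_start,
            graph_start_time=graph_start,
        )
    )
    # e.g. "000.000123 (000.000089 s): expanded retrieval finished"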
57
backend/onyx/configs/dev_configs.py
Normal file
@@ -0,0 +1,57 @@
import os

from .chat_configs import NUM_RETURNED_HITS


#####
# Agent Configs
#####

agent_retrieval_stats_os: bool | str | None = os.environ.get(
    "AGENT_RETRIEVAL_STATS", False
)

AGENT_RETRIEVAL_STATS: bool = False
if isinstance(agent_retrieval_stats_os, str) and agent_retrieval_stats_os == "True":
    AGENT_RETRIEVAL_STATS = True
elif isinstance(agent_retrieval_stats_os, bool) and agent_retrieval_stats_os:
    AGENT_RETRIEVAL_STATS = True

agent_max_query_retrieval_results_os: int | str = os.environ.get(
    "AGENT_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS
)

AGENT_MAX_QUERY_RETRIEVAL_RESULTS: int = NUM_RETURNED_HITS
try:
    AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int(agent_max_query_retrieval_results_os)
except ValueError:
    raise ValueError(
        f"AGENT_MAX_QUERY_RETRIEVAL_RESULTS must be an integer, got {agent_max_query_retrieval_results_os}"
    )


# Reranking agent configs
agent_reranking_stats_os: bool | str | None = os.environ.get(
    "AGENT_RERANKING_STATS", False
)
AGENT_RERANKING_STATS: bool = False
if isinstance(agent_reranking_stats_os, str) and agent_reranking_stats_os == "True":
    AGENT_RERANKING_STATS = True
elif isinstance(agent_reranking_stats_os, bool) and agent_reranking_stats_os:
    AGENT_RERANKING_STATS = True


agent_reranking_max_query_retrieval_results_os: int | str = os.environ.get(
    "AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS", NUM_RETURNED_HITS
)

AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS: int = NUM_RETURNED_HITS

try:
    AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int(
        agent_reranking_max_query_retrieval_results_os
    )
except ValueError:
    raise ValueError(
        f"AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS must be an integer, got {agent_reranking_max_query_retrieval_results_os}"
    )
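The boolean parsing above repeats the same str/bool dance for each flag; a small helper along these lines (an illustrative refactor, not part of the diff) would collapse each block to one line:

    import os


    def env_flag(name: str, default: bool = False) -> bool:
        # Treat the literal string "True" (or an already-truthy default) as True.
        value = os.environ.get(name, default)
        if isinstance(value, str):
            return value == "True"
        return bool(value)


    AGENT_RETRIEVAL_STATS = env_flag("AGENT_RETRIEVAL_STATS")
    AGENT_RERANKING_STATS = env_flag("AGENT_RERANKING_STATS")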
@@ -25,6 +25,11 @@ class ToolCallSummary(BaseModel__v1):
     tool_call_request: AIMessage
     tool_call_result: ToolMessage
 
+    # This is a workaround to allow arbitrary types in the model
+    # TODO: Remove this once we have a better solution
+    class Config:
+        arbitrary_types_allowed = True
+
 
 def tool_call_tokens(
     tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer
@@ -26,10 +26,15 @@ huggingface-hub==0.20.1
 jira==3.5.1
 jsonref==1.1.0
 trafilatura==1.12.2
-langchain==0.1.17
-langchain-core==0.1.50
-langchain-text-splitters==0.0.1
-litellm==1.54.1
+langchain==0.3.7
+langchain-core==0.3.24
+langchain-openai==0.2.9
+langchain-text-splitters==0.3.2
+langchainhub==0.1.21
+langgraph==0.2.59
+langgraph-checkpoint==2.0.5
+langgraph-sdk==0.1.44
+litellm==1.53.1
 lxml==5.3.0
 lxml_html_clean==0.2.2
 llama-index==0.9.45
121
backend/tests/regression/answer_quality/agent_test.py
Normal file
@@ -0,0 +1,121 @@
import csv
import datetime
import json
import os

import yaml

from onyx.agent_search.main.graph_builder import main_graph_builder
from onyx.agent_search.main.states import MainInput
from onyx.context.search.models import SearchRequest
from onyx.db.engine import get_session_context_manager
from onyx.llm.factory import get_default_llms

cwd = os.getcwd()
with open(
    f"{cwd}/backend/tests/regression/answer_quality/search_test_config.yaml"
) as config_file:
    CONFIG = yaml.safe_load(config_file)
INPUT_DIR = CONFIG["agent_test_input_folder"]
OUTPUT_DIR = CONFIG["agent_test_output_folder"]


graph = main_graph_builder(test_mode=True)
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()

# create a local json test data file and use it here

with open(f"{INPUT_DIR}/agent_test_data.json") as input_file_object:
    test_data = json.load(input_file_object)
output_file = f"{OUTPUT_DIR}/agent_test_output.csv"

example_data = test_data["examples"]
example_ids = test_data["example_ids"]

with get_session_context_manager() as db_session:
    output_data = []

    for example in example_data:
        example_id = example["id"]
        if len(example_ids) > 0 and example_id not in example_ids:
            continue

        example_question = example["question"]
        target_sub_questions = example.get("target_sub_questions", [])
        num_target_sub_questions = len(target_sub_questions)
        search_request = SearchRequest(query=example_question)

        inputs = MainInput(
            search_request=search_request,
            primary_llm=primary_llm,
            fast_llm=fast_llm,
            db_session=db_session,
        )

        start_time = datetime.datetime.now()

        question_result = compiled_graph.invoke(input=inputs)
        end_time = datetime.datetime.now()

        duration = end_time - start_time
        if num_target_sub_questions > 0:
            chunk_expansion_ratio = (
                question_result["initial_agent_stats"]
                .get("agent_effectiveness", {})
                .get("utilized_chunk_ratio", None)
            )
            support_effectiveness_ratio = (
                question_result["initial_agent_stats"]
                .get("agent_effectiveness", {})
                .get("support_ratio", None)
            )
        else:
            chunk_expansion_ratio = None
            support_effectiveness_ratio = None

        generated_sub_questions = question_result.get("generated_sub_questions", [])
        num_generated_sub_questions = len(generated_sub_questions)
        base_answer = question_result["initial_base_answer"].split("==")[-1]
        agent_answer = question_result["initial_answer"].split("==")[-1]

        output_point = {
            "example_id": example_id,
            "question": example_question,
            "duration": duration,
            "target_sub_questions": target_sub_questions,
            "generated_sub_questions": generated_sub_questions,
            "num_target_sub_questions": num_target_sub_questions,
            "num_generated_sub_questions": num_generated_sub_questions,
            "chunk_expansion_ratio": chunk_expansion_ratio,
            "support_effectiveness_ratio": support_effectiveness_ratio,
            "base_answer": base_answer,
            "agent_answer": agent_answer,
        }

        output_data.append(output_point)


with open(output_file, "w", newline="") as csvfile:
    fieldnames = [
        "example_id",
        "question",
        "duration",
        "target_sub_questions",
        "generated_sub_questions",
        "num_target_sub_questions",
        "num_generated_sub_questions",
        "chunk_expansion_ratio",
        "support_effectiveness_ratio",
        "base_answer",
        "agent_answer",
    ]

    writer = csv.DictWriter(csvfile, fieldnames=fieldnames, delimiter="\t")
    writer.writeheader()
    writer.writerows(output_data)

print("DONE")
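For reference, the script above implies an agent_test_data.json of roughly the following shape (field names taken from the code; the values are illustrative). An empty example_ids list means all examples are run:

    {
        "example_ids": [1],
        "examples": [
            {
                "id": 1,
                "question": "compare sales of company A and company B",
                "target_sub_questions": [
                    "what are sales for company A",
                    "what are sales for company B"
                ]
            }
        ]
    }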