Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits: v2.5.9...initial_im (10 commits)
| SHA1 |
|---|
| 0d848fa9dd |
| 1dbb6f3d69 |
| 91cf9a5472 |
| ebb0e56a30 |
| 091cb136c4 |
| 56052c5b4b |
| 617726207b |
| 1be58e74b3 |
| a693c991d7 |
| 4b28686721 |
backend/danswer/agent_search/core_qa_graph/edges.py (new file, 42 lines)
@@ -0,0 +1,42 @@
from collections.abc import Hashable
from typing import Union

from langgraph.types import Send

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.primary_graph.states import VerifierState


def sub_continue_to_verifier(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
    # Routes each de-duped retrieved doc to the verifier step - in parallel.
    # Notice the 'Send()' API that takes care of the parallelization.
    return [
        Send(
            "sub_verifier",
            VerifierState(
                document=doc,
                # question=state["original_question"],
                question=state["sub_question_str"],
                graph_start_time=state["graph_start_time"],
            ),
        )
        for doc in state["sub_question_deduped_retrieval_docs"]
    ]


def sub_continue_to_retrieval(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
    # Routes re-written queries to the (parallel) retrieval steps.
    # Notice the 'Send()' API that takes care of the parallelization.
    rewritten_queries = state["sub_question_search_queries"].rewritten_queries + [
        state["sub_question_str"]
    ]
    return [
        Send(
            "sub_custom_retrieve",
            RetrieverState(
                rewritten_query=query,
                graph_start_time=state["graph_start_time"],
            ),
        )
        for query in rewritten_queries
    ]
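Both edge functions rely on LangGraph's `Send` API: when a conditional-edge callback returns a list of `Send(node, state)` objects, the target node runs once per object, in parallel, and each invocation sees only the state packaged into its `Send`. A minimal self-contained sketch of the pattern (node and field names here are illustrative, not from this codebase):

```python
import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class FanOutState(TypedDict):
    queries: list[str]
    # operator.add merges the partial results returned by parallel branches
    results: Annotated[list[str], operator.add]


def retrieve(state: dict) -> dict:
    # Each parallel invocation sees only the state packaged into its Send()
    return {"results": [f"docs for: {state['query']}"]}


def fan_out(state: FanOutState) -> list[Send]:
    return [Send("retrieve", {"query": q}) for q in state["queries"]]


builder = StateGraph(FanOutState)
builder.add_node("retrieve", retrieve)
builder.add_conditional_edges(START, fan_out, ["retrieve"])
builder.add_edge("retrieve", END)

print(builder.compile().invoke({"queries": ["a", "b"], "results": []}))
# -> {'queries': ['a', 'b'], 'results': ['docs for: a', 'docs for: b']}
```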
backend/danswer/agent_search/core_qa_graph/graph_builder.py (new file, 132 lines)
@@ -0,0 +1,132 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from danswer.agent_search.core_qa_graph.edges import sub_continue_to_retrieval
from danswer.agent_search.core_qa_graph.edges import sub_continue_to_verifier
from danswer.agent_search.core_qa_graph.nodes.combine_retrieved_docs import (
    sub_combine_retrieved_docs,
)
from danswer.agent_search.core_qa_graph.nodes.custom_retrieve import (
    sub_custom_retrieve,
)
from danswer.agent_search.core_qa_graph.nodes.dummy import sub_dummy
from danswer.agent_search.core_qa_graph.nodes.final_format import (
    sub_final_format,
)
from danswer.agent_search.core_qa_graph.nodes.generate import sub_generate
from danswer.agent_search.core_qa_graph.nodes.qa_check import sub_qa_check
from danswer.agent_search.core_qa_graph.nodes.rewrite import sub_rewrite
from danswer.agent_search.core_qa_graph.nodes.verifier import sub_verifier
from danswer.agent_search.core_qa_graph.states import BaseQAOutputState
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.core_qa_graph.states import CoreQAInputState


def build_core_qa_graph() -> StateGraph:
    sub_answers_initial = StateGraph(
        state_schema=BaseQAState,
        output=BaseQAOutputState,
    )

    ### Add nodes ###
    sub_answers_initial.add_node(node="sub_dummy", action=sub_dummy)
    sub_answers_initial.add_node(node="sub_rewrite", action=sub_rewrite)
    sub_answers_initial.add_node(
        node="sub_custom_retrieve",
        action=sub_custom_retrieve,
    )
    sub_answers_initial.add_node(
        node="sub_combine_retrieved_docs",
        action=sub_combine_retrieved_docs,
    )
    sub_answers_initial.add_node(
        node="sub_verifier",
        action=sub_verifier,
    )
    sub_answers_initial.add_node(
        node="sub_generate",
        action=sub_generate,
    )
    sub_answers_initial.add_node(
        node="sub_qa_check",
        action=sub_qa_check,
    )
    sub_answers_initial.add_node(
        node="sub_final_format",
        action=sub_final_format,
    )

    ### Add edges ###
    sub_answers_initial.add_edge(START, "sub_dummy")
    sub_answers_initial.add_edge("sub_dummy", "sub_rewrite")

    sub_answers_initial.add_conditional_edges(
        source="sub_rewrite",
        path=sub_continue_to_retrieval,
    )

    sub_answers_initial.add_edge(
        start_key="sub_custom_retrieve",
        end_key="sub_combine_retrieved_docs",
    )

    sub_answers_initial.add_conditional_edges(
        source="sub_combine_retrieved_docs",
        path=sub_continue_to_verifier,
        path_map=["sub_verifier"],
    )

    sub_answers_initial.add_edge(
        start_key="sub_verifier",
        end_key="sub_generate",
    )

    sub_answers_initial.add_edge(
        start_key="sub_generate",
        end_key="sub_qa_check",
    )

    sub_answers_initial.add_edge(
        start_key="sub_qa_check",
        end_key="sub_final_format",
    )

    sub_answers_initial.add_edge(
        start_key="sub_final_format",
        end_key=END,
    )
    # sub_answers_graph = sub_answers_initial.compile()
    return sub_answers_initial


if __name__ == "__main__":
    # q = "Whose music is kind of hard to easily enjoy?"
    # q = "What is voice leading?"
    # q = "What are the types of motions in music?"
    # q = "What are key elements of music theory?"
    # q = "How can I best understand music theory using voice leading?"
    q = "What makes good music?"
    # q = "types of motions in music"
    # q = "What is the relationship between music and physics?"
    # q = "Can you compare various grunge styles?"
    # q = "Why is quantum gravity so hard?"

    inputs = CoreQAInputState(
        original_question=q,
        sub_question_str=q,
    )
    sub_answers_graph = build_core_qa_graph()
    compiled_sub_answers = sub_answers_graph.compile()
    output = compiled_sub_answers.invoke(inputs)
    print("\nOUTPUT:")
    print(output.keys())
    for key, value in output.items():
        if key in [
            "sub_question_answer",
            "sub_question_str",
            "sub_qas",
            "initial_sub_qas",
        ]:
            print(f"{key}: {value}")
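Note that the `sub_rewrite` fan-out above omits `path_map`: when the path callback returns `Send` objects, LangGraph routes to whichever node each `Send` names, so the map is optional and mainly helps static graph rendering (the `sub_combine_retrieved_docs` edge passes `path_map=["sub_verifier"]` explicitly).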
backend/danswer/agent_search/core_qa_graph/nodes/combine_retrieved_docs.py (new file, 36 lines)
@@ -0,0 +1,36 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection


def sub_combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]:
    """
    Dedupe the retrieved docs.
    """
    node_start_time = datetime.now()

    sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]

    print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
    dedupe_docs: list[InferenceSection] = []
    for base_retrieval_doc in sub_question_base_retrieval_docs:
        if not any(
            base_retrieval_doc.center_chunk.chunk_id == doc.center_chunk.chunk_id
            for doc in dedupe_docs
        ):
            dedupe_docs.append(base_retrieval_doc)

    print(f"Number of deduped docs: {len(dedupe_docs)}")

    return {
        "sub_question_deduped_retrieval_docs": dedupe_docs,
        "log_messages": generate_log_message(
            message="sub - combine_retrieved_docs (dedupe)",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
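The dedupe loop above rescans the accumulated list for every incoming section, which is quadratic in the number of retrieved docs. That is fine at this scale; if it ever matters, a dict keyed on `chunk_id` gives the same order-preserving, first-occurrence-wins dedupe in linear time (a sketch, mirroring the chunk-id criterion used above):

```python
def dedupe_by_chunk_id(docs: list) -> list:
    # dicts preserve insertion order, so the first occurrence of each id wins
    seen: dict = {}
    for doc in docs:
        seen.setdefault(doc.center_chunk.chunk_id, doc)
    return list(seen.values())
```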
backend/danswer/agent_search/core_qa_graph/nodes/custom_retrieve.py (new file, 66 lines)
@@ -0,0 +1,66 @@
import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms


def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE SUB---")

    node_start_time = datetime.datetime.now()

    rewritten_query = state["rewritten_query"]

    # Retrieval
    # TODO: add the actual retrieval, probably from search_tool.run()
    llm, fast_llm = get_default_llms()
    with get_session_context_manager() as db_session:
        search_pipeline = SearchPipeline(
            search_request=SearchRequest(
                query=rewritten_query,
            ),
            user=None,
            llm=llm,
            fast_llm=fast_llm,
            db_session=db_session,
        )
        # materialize the sections while the db session is still open
        reranked_docs: list[InferenceSection] = search_pipeline.reranked_sections

    # initial metric to measure fit. TODO: implement metric properly
    top_1_score = reranked_docs[0].center_chunk.score
    top_5_score = sum([doc.center_chunk.score for doc in reranked_docs[:5]]) / 5
    top_10_score = sum([doc.center_chunk.score for doc in reranked_docs[:10]]) / 10

    fit_score = 1 / 3 * (top_1_score + top_5_score + top_10_score)

    chunk_ids = {
        "query": rewritten_query,
        "chunk_ids": [doc.center_chunk.chunk_id for doc in reranked_docs],
    }

    return {
        "sub_question_base_retrieval_docs": reranked_docs,
        "sub_chunk_ids": [chunk_ids],
        "log_messages": generate_log_message(
            message=f"sub - custom_retrieve, fit_score: {fit_score}",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
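The provisional fit score averages three quantities: the top-1 score, the mean of the top-5 scores, and the mean of the top-10 scores. As written it divides by a fixed 5 and 10 and indexes `reranked_docs[0]` unconditionally, so it understates fit (or raises) when fewer than ten sections come back; a length-aware variant might look like this (a sketch, not the committed metric):

```python
def fit_score(scores: list[float]) -> float:
    """Average of top-1, mean-of-top-5, and mean-of-top-10 scores."""
    if not scores:
        return 0.0

    def mean_top(k: int) -> float:
        # divide by the number of scores actually present
        top = scores[:k]
        return sum(top) / len(top)

    return (mean_top(1) + mean_top(5) + mean_top(10)) / 3
```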
backend/danswer/agent_search/core_qa_graph/nodes/dummy.py (new file, 24 lines)
@@ -0,0 +1,24 @@
import datetime
from typing import Any

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_dummy(state: BaseQAState) -> dict[str, Any]:
    """
    Dummy step
    """

    print("---Sub Dummy---")

    node_start_time = datetime.datetime.now()

    return {
        "graph_start_time": node_start_time,
        "log_messages": generate_log_message(
            message="sub - dummy",
            node_start_time=node_start_time,
            graph_start_time=node_start_time,
        ),
    }
backend/danswer/agent_search/core_qa_graph/nodes/final_format.py (new file, 22 lines)
@@ -0,0 +1,22 @@
from typing import Any

from danswer.agent_search.core_qa_graph.states import BaseQAState


def sub_final_format(state: BaseQAState) -> dict[str, Any]:
    """
    Create the final output for the QA subgraph
    """

    print("---BASE FINAL FORMAT---")

    return {
        "sub_qas": [
            {
                "sub_question": state["sub_question_str"],
                "sub_answer": state["sub_question_answer"],
                "sub_answer_check": state["sub_question_answer_check"],
            }
        ],
        "log_messages": state["log_messages"],
    }
backend/danswer/agent_search/core_qa_graph/nodes/generate.py (new file, 91 lines)
@@ -0,0 +1,91 @@
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms


def sub_generate(state: BaseQAState) -> dict[str, Any]:
    """
    Generate answer

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the generated sub-answer
    """
    print("---GENERATE---")

    # Create sub-query results
    verified_chunks = [
        chunk.center_chunk.chunk_id
        for chunk in state["sub_question_verified_retrieval_docs"]
    ]
    result_dict = {}

    chunk_id_dicts = state["sub_chunk_ids"]
    expanded_chunks = []
    original_chunks = []

    for chunk_id_dict in chunk_id_dicts:
        sub_question = chunk_id_dict["query"]
        verified_sq_chunks = [
            chunk_id
            for chunk_id in chunk_id_dict["chunk_ids"]
            if chunk_id in verified_chunks
        ]

        if sub_question != state["original_question"]:
            expanded_chunks += verified_sq_chunks
        else:
            result_dict["ORIGINAL"] = len(verified_sq_chunks)
            original_chunks += verified_sq_chunks
        result_dict[sub_question[:30]] = len(verified_sq_chunks)

    expansion_chunks = set(expanded_chunks)
    num_expansion_chunks = sum(
        [1 for chunk_id in expansion_chunks if chunk_id in verified_chunks]
    )
    # (the counts below are computed for experimentation and not yet surfaced)
    num_original_relevant_chunks = len(original_chunks)
    num_missed_relevant_chunks = sum(
        [1 for chunk_id in original_chunks if chunk_id not in expansion_chunks]
    )
    num_gained_relevant_chunks = sum(
        [1 for chunk_id in expansion_chunks if chunk_id not in original_chunks]
    )
    result_dict["expansion_chunks"] = num_expansion_chunks

    print(result_dict)

    node_start_time = datetime.now()

    question = state["sub_question_str"]
    docs = state["sub_question_verified_retrieval_docs"]

    print(f"Number of verified retrieval docs: {len(docs)}")

    # Only take the top 10 docs (the last 10 entries of the verified list).
    # TODO: Make this dynamic or use config param?
    top_10_docs = docs[-10:]

    msg = [
        HumanMessage(
            content=BASE_RAG_PROMPT.format(
                question=question, context=format_docs(top_10_docs)
            )
        )
    ]

    # Grader
    _, fast_llm = get_default_llms()
    response = list(
        fast_llm.stream(
            prompt=msg,
            # structured_response_format=None,
        )
    )

    answer_str = merge_message_runs(response, chunk_separator="")[0].content
    return {
        "sub_question_answer": answer_str,
        "log_messages": generate_log_message(
            message="base - generate",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
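`fast_llm.stream(...)` yields message chunks, and `merge_message_runs` from `langchain_core` concatenates consecutive messages of the same type, so `[0].content` on the merged list recovers the full answer string. A standalone illustration (chunk contents made up):

```python
from langchain_core.messages import AIMessageChunk, merge_message_runs

chunks = [
    AIMessageChunk(content="Good music "),
    AIMessageChunk(content="balances tension "),
    AIMessageChunk(content="and resolution."),
]
merged = merge_message_runs(chunks, chunk_separator="")
print(merged[0].content)  # Good music balances tension and resolution.
```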
backend/danswer/agent_search/core_qa_graph/nodes/qa_check.py (new file, 51 lines)
@@ -0,0 +1,51 @@
import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms


def sub_qa_check(state: BaseQAState) -> dict[str, Any]:
    """
    Check if the sub-question answer is satisfactory.

    Args:
        state: The current SubQAState containing the sub-question and its answer

    Returns:
        dict containing the check result and log message
    """
    node_start_time = datetime.datetime.now()

    msg = [
        HumanMessage(
            content=BASE_CHECK_PROMPT.format(
                question=state["sub_question_str"],
                base_answer=state["sub_question_answer"],
            )
        )
    ]

    _, fast_llm = get_default_llms()
    response = list(
        fast_llm.stream(
            prompt=msg,
            # structured_response_format=None,
        )
    )

    response_str = merge_message_runs(response, chunk_separator="")[0].content

    return {
        "sub_question_answer_check": response_str,
        # "log_messages" is the channel defined on BaseQAState
        "log_messages": generate_log_message(
            message="sub - qa_check",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
backend/danswer/agent_search/core_qa_graph/nodes/rewrite.py (new file, 74 lines)
@@ -0,0 +1,74 @@
import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import (
    REWRITE_PROMPT_MULTI_ORIGINAL,
)
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms


def sub_rewrite(state: BaseQAState) -> dict[str, Any]:
    """
    Transform the initial question into more suitable search queries.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with re-phrased question
    """

    print("---SUB TRANSFORM QUERY---")

    node_start_time = datetime.datetime.now()

    # messages = state["base_answer_messages"]
    question = state["sub_question_str"]

    msg = [
        HumanMessage(
            content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
        )
    ]

    """
    msg = [
        HumanMessage(
            content=REWRITE_PROMPT_MULTI.format(question=question),
        )
    ]
    """

    _, fast_llm = get_default_llms()
    llm_response_list = list(
        fast_llm.stream(
            prompt=msg,
            # structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
            # structured_response_format=RewrittenQueries.model_json_schema(),
        )
    )
    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

    print(f"llm_response: {llm_response}")

    # split on '--', presumably the delimiter the rewrite prompt requests
    rewritten_queries = llm_response.split("--")
    # rewritten_queries = [llm_response.split("\n")[0]]

    print(f"rewritten_queries: {rewritten_queries}")

    rewritten_queries = RewrittenQueries(rewritten_queries=rewritten_queries)

    return {
        "sub_question_search_queries": rewritten_queries,
        "log_messages": generate_log_message(
            message="sub - rewrite",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
backend/danswer/agent_search/core_qa_graph/nodes/verifier.py (new file, 64 lines)
@@ -0,0 +1,64 @@
import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms


def sub_verifier(state: VerifierState) -> dict[str, Any]:
    """
    Check whether the document is relevant for the original user question

    Args:
        state (VerifierState): The current state

    Returns:
        dict: The updated state with the final decision
    """

    # print("---VERIFY OUTPUT---")
    node_start_time = datetime.datetime.now()

    question = state["question"]
    document_content = state["document"].combined_content

    msg = [
        HumanMessage(
            content=VERIFIER_PROMPT.format(
                question=question, document_content=document_content
            )
        )
    ]

    # Grader
    llm, fast_llm = get_default_llms()
    response = list(
        llm.stream(
            prompt=msg,
            # structured_response_format=BinaryDecision.model_json_schema(),
        )
    )

    response_string = merge_message_runs(response, chunk_separator="")[0].content
    # Convert string response to proper dictionary format
    decision_dict = {"decision": response_string.lower()}
    formatted_response = BinaryDecision.model_validate(decision_dict)

    print(f"Verification end time: {datetime.datetime.now()}")

    return {
        "sub_question_verified_retrieval_docs": [state["document"]]
        if formatted_response.decision == "yes"
        else [],
        "log_messages": generate_log_message(
            message=f"sub - verifier: {formatted_response.decision}",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
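The verifier lower-cases the raw model text and validates it into a `BinaryDecision` before branching, so anything other than a clean yes/no fails loudly at validation time instead of silently dropping documents. `BinaryDecision` itself is not part of this diff; from how it is used, a minimal compatible definition would be (an assumption, not the actual Onyx model):

```python
from typing import Literal

from pydantic import BaseModel


class BinaryDecision(BaseModel):
    # Assumed shape: the nodes validate {"decision": "yes" | "no"}
    decision: Literal["yes", "no"]
```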
backend/danswer/agent_search/core_qa_graph/states.py (new file, 90 lines)
@@ -0,0 +1,90 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM


class SubQuestionRetrieverState(TypedDict):
    # The state for the parallel Retrievers. They each need to see only one query
    sub_question_rewritten_query: str


class SubQuestionVerifierState(TypedDict):
    # The state for the parallel verification step. Each node execution needs to see only one question/doc pair
    sub_question_document: InferenceSection
    sub_question: str


class CoreQAInputState(TypedDict):
    sub_question_str: str
    original_question: str


class BaseQAState(TypedDict):
    # The 'core SubQuestion' state.
    original_question: str
    graph_start_time: datetime
    # start time for parallel initial sub-question thread
    sub_query_start_time: datetime
    sub_question_rewritten_queries: list[str]
    sub_question_str: str
    sub_question_search_queries: RewrittenQueries
    sub_question_nr: int
    sub_chunk_ids: Annotated[Sequence[dict], operator.add]
    sub_question_base_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_deduped_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_verified_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_reranked_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    sub_question_answer: str
    sub_question_answer_check: str
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
    sub_qas: Annotated[Sequence[dict], operator.add]
    # Answers sent back to core
    initial_sub_qas: Annotated[Sequence[dict], operator.add]
    primary_llm: LLM
    fast_llm: LLM


class BaseQAOutputState(TypedDict):
    # The 'SubQuestion' output state. Removes all the intermediate states
    sub_question_rewritten_queries: list[str]
    sub_question_str: str
    sub_question_search_queries: list[str]
    sub_question_nr: int
    # Answers sent back to core
    sub_qas: Annotated[Sequence[dict], operator.add]
    # Answers sent back to core
    initial_sub_qas: Annotated[Sequence[dict], operator.add]
    sub_question_base_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_deduped_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_verified_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_reranked_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    sub_question_answer: str
    sub_question_answer_check: str
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
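These `Annotated[..., operator.add]` fields are what make the `Send()` fan-out safe: each parallel retriever or verifier branch returns only its own single-element list, and LangGraph folds the concurrent updates together with `+` (or `add_messages` for the chat log) instead of letting branches overwrite one another.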
backend/danswer/agent_search/deep_qa_graph/edges.py (new file, 46 lines)
@@ -0,0 +1,46 @@
from collections.abc import Hashable
from typing import Union

from langgraph.types import Send

from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.primary_graph.states import VerifierState


def sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Hashable]]:
    # Routes each de-duped retrieved doc to the verifier step - in parallel.
    # Notice the 'Send()' API that takes care of the parallelization.
    return [
        Send(
            "sub_verifier",
            VerifierState(
                document=doc,
                question=state["sub_question"],
                primary_llm=state["primary_llm"],
                fast_llm=state["fast_llm"],
                graph_start_time=state["graph_start_time"],
            ),
        )
        for doc in state["sub_question_base_retrieval_docs"]
    ]


def sub_continue_to_retrieval(
    state: ResearchQAState,
) -> Union[Hashable, list[Hashable]]:
    # Routes re-written queries to the (parallel) retrieval steps.
    # Notice the 'Send()' API that takes care of the parallelization.
    return [
        Send(
            "sub_custom_retrieve",
            RetrieverState(
                rewritten_query=query,
                primary_llm=state["primary_llm"],
                fast_llm=state["fast_llm"],
                graph_start_time=state["graph_start_time"],
            ),
        )
        for query in state["sub_question_rewritten_queries"]
    ]
backend/danswer/agent_search/deep_qa_graph/graph_builder.py (new file, 93 lines)
@@ -0,0 +1,93 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_retrieval
from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_verifier
from danswer.agent_search.deep_qa_graph.nodes.combine_retrieved_docs import (
    sub_combine_retrieved_docs,
)
from danswer.agent_search.deep_qa_graph.nodes.custom_retrieve import sub_custom_retrieve
from danswer.agent_search.deep_qa_graph.nodes.dummy import sub_dummy
from danswer.agent_search.deep_qa_graph.nodes.final_format import sub_final_format
from danswer.agent_search.deep_qa_graph.nodes.generate import sub_generate
from danswer.agent_search.deep_qa_graph.nodes.qa_check import sub_qa_check
from danswer.agent_search.deep_qa_graph.nodes.rewrite import sub_rewrite
from danswer.agent_search.deep_qa_graph.nodes.verifier import sub_verifier
from danswer.agent_search.deep_qa_graph.states import ResearchQAOutputState
from danswer.agent_search.deep_qa_graph.states import ResearchQAState


def build_deep_qa_graph() -> StateGraph:
    # Define the nodes we will cycle between
    sub_answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState)

    ### Add Nodes ###

    # Dummy node for initial processing
    sub_answers.add_node(node="sub_dummy", action=sub_dummy)

    # Re-writing the sub-question into search queries
    # (registered here so the START -> "sub_rewrite" edge below resolves)
    sub_answers.add_node(node="sub_rewrite", action=sub_rewrite)

    # The retrieval step
    sub_answers.add_node(node="sub_custom_retrieve", action=sub_custom_retrieve)

    # The dedupe step
    sub_answers.add_node(
        node="sub_combine_retrieved_docs", action=sub_combine_retrieved_docs
    )

    # Verifying retrieved information
    sub_answers.add_node(node="sub_verifier", action=sub_verifier)

    # Generating the response
    sub_answers.add_node(node="sub_generate", action=sub_generate)

    # Checking the quality of the answer
    sub_answers.add_node(node="sub_qa_check", action=sub_qa_check)

    # Final formatting of the response
    sub_answers.add_node(node="sub_final_format", action=sub_final_format)

    ### Add Edges ###

    # Generate multiple sub-questions
    sub_answers.add_edge(start_key=START, end_key="sub_rewrite")

    # For each sub-question, perform a retrieval in parallel
    sub_answers.add_conditional_edges(
        source="sub_rewrite",
        path=sub_continue_to_retrieval,
        path_map=["sub_custom_retrieve"],
    )

    # Combine the retrieved docs for each sub-question from the parallel retrievals
    sub_answers.add_edge(
        start_key="sub_custom_retrieve", end_key="sub_combine_retrieved_docs"
    )

    # Go over all of the combined retrieved docs and verify them against the original question
    sub_answers.add_conditional_edges(
        source="sub_combine_retrieved_docs",
        path=sub_continue_to_verifier,
        path_map=["sub_verifier"],
    )

    # Generate an answer for each verified retrieved doc
    sub_answers.add_edge(start_key="sub_verifier", end_key="sub_generate")

    # Check the quality of the answer
    sub_answers.add_edge(start_key="sub_generate", end_key="sub_qa_check")

    sub_answers.add_edge(start_key="sub_qa_check", end_key="sub_final_format")

    sub_answers.add_edge(start_key="sub_final_format", end_key=END)

    return sub_answers


if __name__ == "__main__":
    # TODO: add the actual question
    inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"}
    sub_answers_graph = build_deep_qa_graph()
    compiled_sub_answers = sub_answers_graph.compile()
    output = compiled_sub_answers.invoke(inputs)
    print("\nOUTPUT:")
    print(output)
backend/danswer/agent_search/deep_qa_graph/nodes/combine_retrieved_docs.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]:
    """
    Dedupe the retrieved docs.
    """
    node_start_time = datetime.now()

    sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]

    print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
    dedupe_docs = []
    for base_retrieval_doc in sub_question_base_retrieval_docs:
        if base_retrieval_doc not in dedupe_docs:
            dedupe_docs.append(base_retrieval_doc)

    print(f"Number of deduped docs: {len(dedupe_docs)}")

    return {
        "sub_question_deduped_retrieval_docs": dedupe_docs,
        "log_messages": generate_log_message(
            message="sub - combine_retrieved_docs (dedupe)",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
backend/danswer/agent_search/deep_qa_graph/nodes/custom_retrieve.py (new file, 33 lines)
@@ -0,0 +1,33 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection


def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE SUB---")
    node_start_time = datetime.now()

    # Retrieval
    # TODO: add the actual retrieval, probably from search_tool.run()
    documents: list[InferenceSection] = []

    return {
        "sub_question_base_retrieval_docs": documents,
        "log_messages": generate_log_message(
            message="sub - custom_retrieve",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
backend/danswer/agent_search/deep_qa_graph/nodes/dummy.py (new file, 21 lines)
@@ -0,0 +1,21 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_dummy(state: BaseQAState) -> dict[str, Any]:
    """
    Dummy step
    """

    print("---Sub Dummy---")

    return {
        "log_messages": generate_log_message(
            message="sub - dummy",
            node_start_time=datetime.now(),
            graph_start_time=state["graph_start_time"],
        ),
    }
backend/danswer/agent_search/deep_qa_graph/nodes/final_format.py (new file, 31 lines)
@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_final_format(state: ResearchQAState) -> dict[str, Any]:
    """
    Create the final output for the QA subgraph
    """

    print("---SUB FINAL FORMAT---")
    node_start_time = datetime.now()

    return {
        # TODO: Type this
        "sub_qas": [
            {
                "sub_question": state["sub_question"],
                "sub_answer": state["sub_question_answer"],
                "sub_question_nr": state["sub_question_nr"],
                "sub_answer_check": state["sub_question_answer_check"],
            }
        ],
        "log_messages": generate_log_message(
            message="sub - final format",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
backend/danswer/agent_search/deep_qa_graph/nodes/generate.py (new file, 56 lines)
@@ -0,0 +1,56 @@
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_generate(state: ResearchQAState) -> dict[str, Any]:
    """
    Generate answer

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the generated sub-answer
    """
    print("---SUB GENERATE---")
    node_start_time = datetime.now()

    question = state["sub_question"]
    docs = state["sub_question_verified_retrieval_docs"]

    print(f"Number of verified retrieval docs for sub-question: {len(docs)}")

    msg = [
        HumanMessage(
            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
        )
    ]

    # Grader
    if len(docs) > 0:
        model = state["fast_llm"]
        response = list(
            model.stream(
                prompt=msg,
            )
        )
        response_str = merge_message_runs(response, chunk_separator="")[0].content
    else:
        response_str = ""

    return {
        "sub_question_answer": response_str,
        "log_messages": generate_log_message(
            message="sub - generate",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
backend/danswer/agent_search/deep_qa_graph/nodes/qa_check.py (new file, 57 lines)
@@ -0,0 +1,57 @@
import json
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.deep_qa_graph.prompts import SUB_CHECK_PROMPT
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_qa_check(state: ResearchQAState) -> dict[str, Any]:
    """
    Check whether the final output satisfies the original user question

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the final decision
    """

    print("---CHECK SUB OUTPUT---")
    node_start_time = datetime.now()

    sub_answer = state["sub_question_answer"]
    sub_question = state["sub_question"]

    msg = [
        HumanMessage(
            content=SUB_CHECK_PROMPT.format(
                sub_question=sub_question, sub_answer=sub_answer
            )
        )
    ]

    # Grader
    model = state["fast_llm"]
    response = list(
        model.stream(
            prompt=msg,
            structured_response_format=BinaryDecision.model_json_schema(),
        )
    )

    raw_response = json.loads(response[0].pretty_repr())
    formatted_response = BinaryDecision.model_validate(raw_response)

    return {
        "sub_question_answer_check": formatted_response.decision,
        "log_messages": generate_log_message(
            message=f"sub - qa check: {formatted_response.decision}",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
backend/danswer/agent_search/deep_qa_graph/nodes/rewrite.py (new file, 64 lines)
@@ -0,0 +1,64 @@
import json
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.interfaces import LLM


def sub_rewrite(state: ResearchQAState) -> dict[str, Any]:
    """
    Transform the initial question into more suitable search queries.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with re-phrased question
    """

    print("---SUB TRANSFORM QUERY---")
    node_start_time = datetime.now()

    question = state["sub_question"]

    msg = [
        HumanMessage(
            content=REWRITE_PROMPT_MULTI.format(question=question),
        )
    ]
    fast_llm: LLM = state["fast_llm"]
    llm_response = list(
        fast_llm.stream(
            prompt=msg,
            structured_response_format=RewrittenQueries.model_json_schema(),
        )
    )

    # Get the rewritten queries in a defined format (json.loads yields a plain dict)
    rewritten_queries_raw = json.loads(llm_response[0].pretty_repr())

    print(f"rewritten_queries: {rewritten_queries_raw}")

    # NOTE: the LLM output above is currently discarded and replaced with
    # hardcoded queries, presumably for debugging
    rewritten_queries = RewrittenQueries(
        rewritten_queries=[
            "music hard to listen to",
            "Music that is not fun or pleasant",
        ]
    )

    print(f"hardcoded rewritten_queries: {rewritten_queries}")

    return {
        "sub_question_rewritten_queries": rewritten_queries,
        "log_messages": generate_log_message(
            message="sub - rewrite",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
backend/danswer/agent_search/deep_qa_graph/nodes/verifier.py (new file, 59 lines)
@@ -0,0 +1,59 @@
import json
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_verifier(state: VerifierState) -> dict[str, Any]:
    """
    Check whether the document is relevant for the original user question

    Args:
        state (VerifierState): The current state

    Returns:
        dict: The updated state with the final decision
    """

    print("---SUB VERIFY OUTPUT---")
    node_start_time = datetime.now()

    question = state["question"]
    document_content = state["document"].combined_content

    msg = [
        HumanMessage(
            content=VERIFIER_PROMPT.format(
                question=question, document_content=document_content
            )
        )
    ]

    # Grader
    model = state["fast_llm"]
    response = list(
        model.stream(
            prompt=msg,
            structured_response_format=BinaryDecision.model_json_schema(),
        )
    )

    raw_response = json.loads(response[0].pretty_repr())
    formatted_response = BinaryDecision.model_validate(raw_response)

    return {
        "deduped_retrieval_docs": [state["document"]]
        if formatted_response.decision == "yes"
        else [],
        "log_messages": generate_log_message(
            message=f"core - verifier: {formatted_response.decision}",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
backend/danswer/agent_search/deep_qa_graph/prompts.py (new file, 13 lines)
@@ -0,0 +1,13 @@
SUB_CHECK_PROMPT = """ \n
Please check whether the suggested answer seems to address the original question.

Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{sub_question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{sub_answer}
\n ------- \n
Please answer with yes or no:"""
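Usage mirrors `sub_qa_check` above: the two placeholders are filled with the sub-question and its candidate answer, and the model is asked for a bare yes/no (example values are illustrative):

```python
prompt = SUB_CHECK_PROMPT.format(
    sub_question="What is voice leading?",
    sub_answer="Voice leading is how individual melodic lines move between chords.",
)
# The rendered prompt ends with "Please answer with yes or no:"
print(prompt)
```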
backend/danswer/agent_search/deep_qa_graph/states.py (new file, 64 lines)
@@ -0,0 +1,64 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM


class ResearchQAState(TypedDict):
    # The 'core SubQuestion' state.
    original_question: str
    graph_start_time: datetime
    sub_question_rewritten_queries: list[str]
    sub_question: str
    sub_question_nr: int
    sub_question_base_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_deduped_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_verified_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_reranked_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    sub_question_answer: str
    sub_question_answer_check: str
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
    sub_qas: Annotated[Sequence[dict], operator.add]
    primary_llm: LLM
    fast_llm: LLM


class ResearchQAOutputState(TypedDict):
    # The 'SubQuestion' output state. Removes all the intermediate states
    sub_question_rewritten_queries: list[str]
    sub_question: str
    sub_question_nr: int
    # Answers sent back to core
    sub_qas: Annotated[Sequence[dict], operator.add]
    sub_question_base_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_deduped_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_verified_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_reranked_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    sub_question_answer: str
    sub_question_answer_check: str
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
backend/danswer/agent_search/primary_graph/edges.py (new file, 75 lines)
@@ -0,0 +1,75 @@
from collections.abc import Hashable
from typing import Union

from langchain_core.messages import HumanMessage
from langgraph.types import Send

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT


def continue_to_initial_sub_questions(
    state: QAState,
) -> Union[Hashable, list[Hashable]]:
    # Routes re-written queries to the (parallel) retrieval steps
    # Notice the 'Send()' API that takes care of the parallelization
    return [
        Send(
            "sub_answers_graph_initial",
            BaseQAState(
                sub_question_str=initial_sub_question["sub_question_str"],
                sub_question_search_queries=initial_sub_question[
                    "sub_question_search_queries"
                ],
                sub_question_nr=initial_sub_question["sub_question_nr"],
                primary_llm=state["primary_llm"],
                fast_llm=state["fast_llm"],
                graph_start_time=state["graph_start_time"],
            ),
        )
        for initial_sub_question in state["initial_sub_questions"]
    ]


def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
    # Routes re-written queries to the (parallel) retrieval steps
    # Notice the 'Send()' API that takes care of the parallelization
    return [
        Send(
            "sub_answers_graph",
            ResearchQAState(
                sub_question=sub_question["sub_question_str"],
                sub_question_nr=sub_question["sub_question_nr"],
                graph_start_time=state["graph_start_time"],
                primary_llm=state["primary_llm"],
                fast_llm=state["fast_llm"],
            ),
        )
        for sub_question in state["sub_questions"]
    ]


def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
    print("---GO TO DEEP ANSWER OR END---")

    base_answer = state["base_answer"]

    question = state["original_question"]

    BASE_CHECK_MESSAGE = [
        HumanMessage(
            content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
        )
    ]

    model = state["fast_llm"]
    response = model.invoke(BASE_CHECK_MESSAGE)

    print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")

    if response.pretty_repr() == "no":
        return "decompose"
    else:
        return "end"
backend/danswer/agent_search/primary_graph/graph_builder.py (new file, 171 lines)
@@ -0,0 +1,171 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from danswer.agent_search.core_qa_graph.graph_builder import build_core_qa_graph
from danswer.agent_search.deep_qa_graph.graph_builder import build_deep_qa_graph
from danswer.agent_search.primary_graph.edges import continue_to_answer_sub_questions
from danswer.agent_search.primary_graph.edges import continue_to_deep_answer
from danswer.agent_search.primary_graph.edges import continue_to_initial_sub_questions
from danswer.agent_search.primary_graph.nodes.base_wait import base_wait
from danswer.agent_search.primary_graph.nodes.combine_retrieved_docs import (
    combine_retrieved_docs,
)
from danswer.agent_search.primary_graph.nodes.custom_retrieve import custom_retrieve
from danswer.agent_search.primary_graph.nodes.decompose import decompose
from danswer.agent_search.primary_graph.nodes.deep_answer_generation import (
    deep_answer_generation,
)
from danswer.agent_search.primary_graph.nodes.dummy_start import dummy_start
from danswer.agent_search.primary_graph.nodes.entity_term_extraction import (
    entity_term_extraction,
)
from danswer.agent_search.primary_graph.nodes.final_stuff import final_stuff
from danswer.agent_search.primary_graph.nodes.generate_initial import generate_initial
from danswer.agent_search.primary_graph.nodes.main_decomp_base import main_decomp_base
from danswer.agent_search.primary_graph.nodes.rewrite import rewrite
from danswer.agent_search.primary_graph.nodes.sub_qa_level_aggregator import (
    sub_qa_level_aggregator,
)
from danswer.agent_search.primary_graph.nodes.sub_qa_manager import sub_qa_manager
from danswer.agent_search.primary_graph.nodes.verifier import verifier
from danswer.agent_search.primary_graph.states import QAState


def build_core_graph() -> StateGraph:
    # Define the nodes we will cycle between
    core_answer_graph = StateGraph(state_schema=QAState)

    ### Add Nodes ###
    core_answer_graph.add_node(node="dummy_start", action=dummy_start)

    # Re-writing the question
    core_answer_graph.add_node(node="rewrite", action=rewrite)

    # The retrieval step
    core_answer_graph.add_node(node="custom_retrieve", action=custom_retrieve)

    # Combine and dedupe retrieved docs.
    core_answer_graph.add_node(
        node="combine_retrieved_docs", action=combine_retrieved_docs
    )

    # Extract entities, terms and relationships
    core_answer_graph.add_node(
        node="entity_term_extraction", action=entity_term_extraction
    )

    # Verifying that a retrieved doc is relevant
    core_answer_graph.add_node(node="verifier", action=verifier)

    # Initial question decomposition
    core_answer_graph.add_node(node="main_decomp_base", action=main_decomp_base)

    # Build the base QA sub-graph and compile it
    compiled_core_qa_graph = build_core_qa_graph().compile()
    # Add the compiled base QA sub-graph as a node to the core graph
    core_answer_graph.add_node(
        node="sub_answers_graph_initial", action=compiled_core_qa_graph
    )

    # Checking whether the initial answer is in the ballpark
    core_answer_graph.add_node(node="base_wait", action=base_wait)

    # Decompose the question into sub-questions
    core_answer_graph.add_node(node="decompose", action=decompose)

    # Manage the sub-questions
    core_answer_graph.add_node(node="sub_qa_manager", action=sub_qa_manager)

    # Build the research QA sub-graph and compile it
    compiled_deep_qa_graph = build_deep_qa_graph().compile()
    # Add the compiled research QA sub-graph as a node to the core graph
    core_answer_graph.add_node(node="sub_answers_graph", action=compiled_deep_qa_graph)

    # Aggregate the sub-questions
    core_answer_graph.add_node(
        node="sub_qa_level_aggregator", action=sub_qa_level_aggregator
    )

    # aggregate sub questions and answers
    core_answer_graph.add_node(
        node="deep_answer_generation", action=deep_answer_generation
    )

    # A final clean-up step
    core_answer_graph.add_node(node="final_stuff", action=final_stuff)

    # Generating a response after we know the documents are relevant
    core_answer_graph.add_node(node="generate_initial", action=generate_initial)

    ### Add Edges ###

    # start the initial sub-question decomposition
    core_answer_graph.add_edge(start_key=START, end_key="main_decomp_base")

    core_answer_graph.add_conditional_edges(
        source="main_decomp_base",
        path=continue_to_initial_sub_questions,
    )

    # use the retrieved information to generate the answer
    core_answer_graph.add_edge(
        start_key=["verifier", "sub_answers_graph_initial"],
        end_key="generate_initial",
    )
    core_answer_graph.add_edge(start_key="generate_initial", end_key="base_wait")

    core_answer_graph.add_conditional_edges(
        source="base_wait",
        path=continue_to_deep_answer,
        path_map={"decompose": "entity_term_extraction", "end": "final_stuff"},
    )

    core_answer_graph.add_edge(start_key="entity_term_extraction", end_key="decompose")

    core_answer_graph.add_edge(start_key="decompose", end_key="sub_qa_manager")
    core_answer_graph.add_conditional_edges(
        source="sub_qa_manager",
        path=continue_to_answer_sub_questions,
    )

    core_answer_graph.add_edge(
        start_key="sub_answers_graph", end_key="sub_qa_level_aggregator"
    )

    core_answer_graph.add_edge(
        start_key="sub_qa_level_aggregator", end_key="deep_answer_generation"
    )

    core_answer_graph.add_edge(
        start_key="deep_answer_generation", end_key="final_stuff"
    )

    core_answer_graph.add_edge(start_key="final_stuff", end_key=END)

    core_answer_graph.compile()

    return core_answer_graph
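A compiled LangGraph graph is itself a runnable and can be added as a node, which is how the two QA sub-graphs are embedded here; parent and child communicate through the overlapping keys of their state schemas. A minimal sketch of the subgraph-as-node pattern (names are illustrative):

```python
from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    question: str
    answer: str


def answer_node(state: State) -> dict:
    return {"answer": f"answer to: {state['question']}"}


# Child graph, compiled into a runnable
child = StateGraph(State)
child.add_node("answer", answer_node)
child.add_edge(START, "answer")
child.add_edge("answer", END)

# Parent graph embeds the compiled child as an ordinary node
parent = StateGraph(State)
parent.add_node("qa_subgraph", child.compile())
parent.add_edge(START, "qa_subgraph")
parent.add_edge("qa_subgraph", END)

print(parent.compile().invoke({"question": "What makes good music?", "answer": ""}))
```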
backend/danswer/agent_search/primary_graph/nodes/base_wait.py (new file, 27 lines)
@@ -0,0 +1,27 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def base_wait(state: QAState) -> dict[str, Any]:
    """
    Ensures that all required steps are completed before proceeding to the next step

    Args:
        state (messages): The current state

    Returns:
        dict: {} (no operation, just logging)
    """

    print("---Base Wait ---")
    node_start_time = datetime.now()
    return {
        "log_messages": generate_log_message(
            message="core - base_wait",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -0,0 +1,36 @@
from collections.abc import Sequence
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection


def combine_retrieved_docs(state: QAState) -> dict[str, Any]:
    """
    Dedupe the retrieved docs.
    """
    node_start_time = datetime.now()

    base_retrieval_docs: Sequence[InferenceSection] = state["base_retrieval_docs"]

    print(f"Number of docs from steps: {len(base_retrieval_docs)}")
    dedupe_docs: list[InferenceSection] = []
    for base_retrieval_doc in base_retrieval_docs:
        if not any(
            base_retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id
            for doc in dedupe_docs
        ):
            dedupe_docs.append(base_retrieval_doc)

    print(f"Number of deduped docs: {len(dedupe_docs)}")

    return {
        "deduped_retrieval_docs": dedupe_docs,
        "log_messages": generate_log_message(
            message="core - combine_retrieved_docs (dedupe)",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
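The membership check above rescans the deduped list for every incoming doc, which is quadratic in the number of documents; a linear-time sketch of the same dedupe, keyed on the center-chunk document id, could look like this:

# Sketch only - dict-based equivalent of the dedupe loop above.
from collections.abc import Sequence

from danswer.context.search.models import InferenceSection


def dedupe_by_document_id(docs: Sequence[InferenceSection]) -> list[InferenceSection]:
    seen: dict[str, InferenceSection] = {}
    for doc in docs:
        # keep only the first occurrence of each document id
        seen.setdefault(doc.center_chunk.document_id, doc)
    return list(seen.values())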
@@ -0,0 +1,52 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms


def custom_retrieve(state: RetrieverState) -> dict[str, Any]:
    """
    Retrieve documents

    Args:
        retriever_state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")

    node_start_time = datetime.now()

    query = state["rewritten_query"]

    # Retrieval
    # TODO: replace this with the actual retrieval, probably from search_tool.run()
    llm, fast_llm = get_default_llms()
    with get_session_context_manager() as db_session:
        top_sections = SearchPipeline(
            search_request=SearchRequest(
                query=query,
            ),
            user=None,
            llm=llm,
            fast_llm=fast_llm,
            db_session=db_session,
        ).reranked_sections
    print(len(top_sections))
    # hand the reranked sections to the downstream verifier/generation nodes
    documents: list[InferenceSection] = top_sections

    return {
        "base_retrieval_docs": documents,
        "log_messages": generate_log_message(
            message="core - custom_retrieve",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -0,0 +1,78 @@
import json
import re
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_entity_term_extraction
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def decompose(state: QAState) -> dict[str, Any]:
    """Decompose the original question into new sub-questions, informed by the
    unsatisfactory base answer, the extracted entities/terms, and the
    sub-questions that were already answered or that failed."""

    node_start_time = datetime.now()

    question = state["original_question"]
    base_answer = state["base_answer"]

    # get the entity term extraction dict and properly format it
    entity_term_extraction_dict = state["retrieved_entities_relationships"][
        "retrieved_entities_relationships"
    ]

    entity_term_extraction_str = format_entity_term_extraction(
        entity_term_extraction_dict
    )

    initial_question_answers = state["initial_sub_qas"]

    addressed_question_list = [
        x["sub_question"]
        for x in initial_question_answers
        if x["sub_answer_check"] == "yes"
    ]
    failed_question_list = [
        x["sub_question"]
        for x in initial_question_answers
        if x["sub_answer_check"] == "no"
    ]

    msg = [
        HumanMessage(
            content=DEEP_DECOMPOSE_PROMPT.format(
                question=question,
                entity_term_extraction_str=entity_term_extraction_str,
                base_answer=base_answer,
                answered_sub_questions="\n - ".join(addressed_question_list),
                failed_sub_questions="\n - ".join(failed_question_list),
            ),
        )
    ]

    # Grader
    model = state["fast_llm"]
    response = model.invoke(msg)

    # strip any markdown fences before parsing the JSON payload
    cleaned_response = re.sub(r"```json\n|\n```", "", response.content)
    parsed_response = json.loads(cleaned_response)

    sub_questions_dict = {}
    for sub_question_nr, sub_question_dict in enumerate(
        parsed_response["sub_questions"]
    ):
        sub_question_dict["answered"] = False
        sub_question_dict["verified"] = False
        sub_questions_dict[sub_question_nr] = sub_question_dict

    return {
        "decomposed_sub_questions_dict": sub_questions_dict,
        "log_messages": generate_log_message(
            message="deep - decompose",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
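For illustration, a hypothetical parsed DEEP_DECOMPOSE_PROMPT response and the dict that the enumeration loop above builds from it:

# Hypothetical model output once the markdown fences are stripped:
parsed_response = {
    "sub_questions": [
        {
            "sub_question": "What are sales for company A?",
            "search_term": "company A sales",
        }
    ]
}
# The loop tags each entry and keys it by position:
# {0: {"sub_question": "What are sales for company A?",
#      "search_term": "company A sales",
#      "answered": False,
#      "verified": False}}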
@@ -0,0 +1,61 @@
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace


# aggregate sub-questions and answers into the final deep answer
def deep_answer_generation(state: QAState) -> dict[str, Any]:
    """
    Generate the deep answer from the verified docs and the sub-question context

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the deep answer
    """
    print("---DEEP GENERATE---")

    node_start_time = datetime.now()

    question = state["original_question"]
    docs = state["deduped_retrieval_docs"]

    deep_answer_context = state["core_answer_dynamic_context"]

    print(f"Number of verified retrieval docs - deep: {len(docs)}")

    combined_context = normalize_whitespace(
        COMBINED_CONTEXT.format(
            deep_answer_context=deep_answer_context, formated_docs=format_docs(docs)
        )
    )

    msg = [
        HumanMessage(
            content=MODIFIED_RAG_PROMPT.format(
                question=question, combined_context=combined_context
            )
        )
    ]

    # Grader
    model = state["fast_llm"]
    response = model.invoke(msg)

    return {
        "deep_answer": response.content,
        "log_messages": generate_log_message(
            message="deep - deep answer generation",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -0,0 +1,11 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState


def dummy_start(state: QAState) -> dict[str, Any]:
    """
    Dummy node to set the start time
    """
    return {"start_time": datetime.now()}
@@ -0,0 +1,51 @@
import json
import re
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.primary_graph.prompts import ENTITY_TERM_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms


def entity_term_extraction(state: QAState) -> dict[str, Any]:
    """Extract entities and terms from the question and context"""
    node_start_time = datetime.now()

    question = state["original_question"]
    docs = state["deduped_retrieval_docs"]

    doc_context = format_docs(docs)

    msg = [
        HumanMessage(
            content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
        )
    ]
    _, fast_llm = get_default_llms()
    # Grader
    llm_response_list = list(
        fast_llm.stream(
            prompt=msg,
            # structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
            # structured_response_format=RewrittenQueries.model_json_schema(),
        )
    )
    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

    cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
    parsed_response = json.loads(cleaned_response)

    return {
        "retrieved_entities_relationships": parsed_response,
        "log_messages": generate_log_message(
            message="deep - entity term extraction",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
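For illustration, a hypothetical extraction in the shape that ENTITY_TERM_PROMPT requests; decompose() later unwraps the inner dict and hands it to format_entity_term_extraction():

# Hypothetical parsed_response for a sales question:
parsed_response = {
    "retrieved_entities_relationships": {
        "entities": [{"entity_name": "Company A", "entity_type": "company"}],
        "relationships": [
            {"name": "sells_to", "type": "sales_to", "entities": ["Us", "Company A"]}
        ],
        "terms": [
            {"term_name": "revenue", "term_type": "revenue", "similar_to": ["sales"]}
        ],
    }
}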
@@ -0,0 +1,85 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def final_stuff(state: QAState) -> dict[str, Any]:
    """
    Final clean-up step: print the time-ordered log messages and the final
    answers (the base answer and, if one was generated, the deep answer with
    its verified sub-questions and answers).

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the final log message
    """
    print("---FINAL---")
    node_start_time = datetime.now()

    messages = state["log_messages"]
    time_ordered_messages = [x.pretty_repr() for x in messages]
    time_ordered_messages.sort()

    print("Message Log:")
    print("\n".join(time_ordered_messages))

    initial_sub_qas = state["initial_sub_qas"]
    initial_sub_qa_list = []
    for initial_sub_qa in initial_sub_qas:
        if initial_sub_qa["sub_answer_check"] == "yes":
            initial_sub_qa_list.append(
                f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----'
            )

    initial_sub_qa_context = "\n".join(initial_sub_qa_list)

    log_message = generate_log_message(
        message="all - final_stuff",
        node_start_time=node_start_time,
        graph_start_time=state["graph_start_time"],
    )

    print(log_message)
    print("--------------------------------")

    base_answer = state["base_answer"]

    print(f"Final Base Answer:\n{base_answer}")
    print("--------------------------------")
    print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
    print("--------------------------------")

    if not state.get("deep_answer"):
        print("No Deep Answer was required")
        return {
            "log_messages": log_message,
        }

    deep_answer = state["deep_answer"]
    sub_qas = state["sub_qas"]
    sub_qa_list = []
    for sub_qa in sub_qas:
        if sub_qa["sub_answer_check"] == "yes":
            sub_qa_list.append(
                f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----'
            )

    sub_qa_context = "\n".join(sub_qa_list)

    print(f"Final Base Answer:\n{base_answer}")
    print("--------------------------------")
    print(f"Final Deep Answer:\n{deep_answer}")
    print("--------------------------------")
    print("Sub Questions and Answers:")
    print(sub_qa_context)

    return {
        "log_messages": generate_log_message(
            message="all - final_stuff",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
52
backend/danswer/agent_search/primary_graph/nodes/generate.py
Normal file
@@ -0,0 +1,52 @@
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def generate(state: QAState) -> dict[str, Any]:
    """
    Generate answer

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the base answer
    """
    print("---GENERATE---")
    node_start_time = datetime.now()

    question = state["original_question"]
    docs = state["deduped_retrieval_docs"]

    print(f"Number of verified retrieval docs: {len(docs)}")

    msg = [
        HumanMessage(
            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
        )
    ]

    # Grader
    llm = state["fast_llm"]
    response = list(
        llm.stream(
            prompt=msg,
            structured_response_format=None,
        )
    )
    # merge the streamed chunks into a single message before reading the answer
    answer = merge_message_runs(response, chunk_separator="")[0].content

    return {
        "base_answer": answer,
        "log_messages": generate_log_message(
            message="core - generate",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -0,0 +1,72 @@
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def generate_initial(state: QAState) -> dict[str, Any]:
    """
    Generate the initial base answer from the verified docs and the answered
    initial sub-questions

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the initial base answer
    """
    print("---GENERATE INITIAL---")
    node_start_time = datetime.now()

    question = state["original_question"]
    docs = state["deduped_retrieval_docs"]
    print(f"Number of verified retrieval docs - base: {len(docs)}")

    sub_question_answers = state["initial_sub_qas"]

    sub_question_answers_list = []

    _SUB_QUESTION_ANSWER_TEMPLATE = """
    Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
    """
    for sub_question_answer_dict in sub_question_answers:
        if (
            sub_question_answer_dict["sub_answer_check"] == "yes"
            and len(sub_question_answer_dict["sub_answer"]) > 0
            and sub_question_answer_dict["sub_answer"] != "I don't know"
        ):
            sub_question_answers_list.append(
                _SUB_QUESTION_ANSWER_TEMPLATE.format(
                    sub_question=sub_question_answer_dict["sub_question"],
                    sub_answer=sub_question_answer_dict["sub_answer"],
                )
            )

    sub_question_answer_str = "\n\n------\n\n".join(sub_question_answers_list)

    msg = [
        HumanMessage(
            content=INITIAL_RAG_PROMPT.format(
                question=question,
                context=format_docs(docs),
                answered_sub_questions=sub_question_answer_str,
            )
        )
    ]

    # Grader
    model = state["fast_llm"]
    response = model.invoke(msg)

    return {
        "base_answer": response.content,
        "log_messages": generate_log_message(
            message="core - generate initial",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -0,0 +1,64 @@
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.prompts import INITIAL_DECOMPOSITION_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def main_decomp_base(state: QAState) -> dict[str, Any]:
    """
    Perform an initial question decomposition, incl. one search term

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with initial decomposition
    """

    print("---INITIAL DECOMP---")
    node_start_time = datetime.now()

    question = state["original_question"]

    msg = [
        HumanMessage(
            content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
        )
    ]

    # Get the initial sub-questions in a defined format
    model = state["fast_llm"]
    response = model.invoke(msg)

    content = response.content
    list_of_subquestions = clean_and_parse_list_string(content)

    decomp_list = []

    for sub_question_nr, sub_question in enumerate(list_of_subquestions):
        sub_question_str = sub_question["sub_question"].strip()
        # temporarily use the single generated search term per sub-question
        sub_question_search_queries = [sub_question["search_term"]]

        decomp_list.append(
            {
                "sub_question_str": sub_question_str,
                "sub_question_search_queries": sub_question_search_queries,
                "sub_question_nr": sub_question_nr,
            }
        )

    return {
        "initial_sub_questions": decomp_list,
        "sub_query_start_time": node_start_time,
        "log_messages": generate_log_message(
            message="core - initial decomp",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
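For reference, the list format that INITIAL_DECOMPOSITION_PROMPT requests and how clean_and_parse_list_string handles a fenced response (values hypothetical):

from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string

raw = '```json\n[{"sub_question": "What are sales for company A?", "search_term": "company A sales"}]\n```'
list_of_subquestions = clean_and_parse_list_string(raw)
# -> [{"sub_question": "What are sales for company A?", "search_term": "company A sales"}]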
55
backend/danswer/agent_search/primary_graph/nodes/rewrite.py
Normal file
@@ -0,0 +1,55 @@
import json
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def rewrite(state: QAState) -> dict[str, Any]:
    """
    Transform the initial question into more suitable search queries.

    Args:
        qa_state (messages): The current state

    Returns:
        dict: The updated state with the rewritten queries
    """
    print("---STARTING GRAPH---")
    graph_start_time = datetime.now()

    print("---TRANSFORM QUERY---")
    node_start_time = datetime.now()

    question = state["original_question"]

    msg = [
        HumanMessage(
            content=REWRITE_PROMPT_MULTI.format(question=question),
        )
    ]

    # Get the rewritten queries in a defined format
    fast_llm = state["fast_llm"]
    llm_response = list(
        fast_llm.stream(
            prompt=msg,
            structured_response_format=RewrittenQueries.model_json_schema(),
        )
    )

    # merge the streamed chunks and validate against the expected schema
    merged_content = merge_message_runs(llm_response, chunk_separator="")[0].content
    formatted_response = RewrittenQueries.model_validate(json.loads(merged_content))

    return {
        "rewritten_queries": formatted_response.rewritten_queries,
        "log_messages": generate_log_message(
            message="core - rewrite",
            node_start_time=node_start_time,
            graph_start_time=graph_start_time,
        ),
    }
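A small parsing sketch, assuming the model honors the RewrittenQueries schema passed via structured_response_format (the sample payload is hypothetical):

import json

from danswer.agent_search.shared_graph_utils.models import RewrittenQueries

raw = '{"rewritten_queries": ["company A sales 2023", "company A revenue"]}'
queries = RewrittenQueries.model_validate(json.loads(raw)).rewritten_queries
# -> ["company A sales 2023", "company A revenue"]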
@@ -0,0 +1,39 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


# aggregate sub questions and answers
def sub_qa_level_aggregator(state: QAState) -> dict[str, Any]:
    sub_qas = state["sub_qas"]

    node_start_time = datetime.now()

    dynamic_context_list = [
        "Below you will find useful information to answer the original question:"
    ]
    checked_sub_qas = []

    for core_answer_sub_qa in sub_qas:
        question = core_answer_sub_qa["sub_question"]
        answer = core_answer_sub_qa["sub_answer"]
        verified = core_answer_sub_qa["sub_answer_check"]

        if verified == "yes":
            dynamic_context_list.append(
                f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n"
            )
            checked_sub_qas.append({"sub_question": question, "sub_answer": answer})
    dynamic_context = "\n".join(dynamic_context_list)

    return {
        "core_answer_dynamic_context": dynamic_context,
        "checked_sub_qas": checked_sub_qas,
        "log_messages": generate_log_message(
            message="deep - sub qa level aggregator",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -0,0 +1,28 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_qa_manager(state: QAState) -> dict[str, Any]:
    """Collect the decomposed sub-questions and reset the iteration counter
    before they are fanned out for answering."""

    node_start_time = datetime.now()

    sub_questions_dict = state["decomposed_sub_questions_dict"]

    sub_questions = {}

    for sub_question_nr, sub_question_dict in sub_questions_dict.items():
        sub_questions[sub_question_nr] = sub_question_dict["sub_question"]

    return {
        "sub_questions": sub_questions,
        "num_new_question_iterations": 0,
        "log_messages": generate_log_message(
            message="deep - sub qa manager",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
59
backend/danswer/agent_search/primary_graph/nodes/verifier.py
Normal file
@@ -0,0 +1,59 @@
import json
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def verifier(state: VerifierState) -> dict[str, Any]:
    """
    Check whether the document is relevant for the original user question

    Args:
        state (VerifierState): The current state

    Returns:
        dict: The updated state with the final decision
    """

    print("---VERIFY OUTPUT---")
    node_start_time = datetime.now()

    question = state["question"]
    document_content = state["document"].combined_content

    msg = [
        HumanMessage(
            content=VERIFIER_PROMPT.format(
                question=question, document_content=document_content
            )
        )
    ]

    # Grader
    llm = state["fast_llm"]
    response = list(
        llm.stream(
            prompt=msg,
            structured_response_format=BinaryDecision.model_json_schema(),
        )
    )

    # merge the streamed chunks and validate the yes/no decision
    merged_content = merge_message_runs(response, chunk_separator="")[0].content
    raw_response = json.loads(merged_content)
    formatted_response = BinaryDecision.model_validate(raw_response)

    return {
        "deduped_retrieval_docs": [state["document"]]
        if formatted_response.decision == "yes"
        else [],
        "log_messages": generate_log_message(
            message=f"core - verifier: {formatted_response.decision}",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
86
backend/danswer/agent_search/primary_graph/prompts.py
Normal file
@@ -0,0 +1,86 @@
INITIAL_DECOMPOSITION_PROMPT = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose of this decomposition is to isolate individual entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n

For each sub-question, please also create one search term that can be used to retrieve relevant
documents from a document store.

Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Please formulate your answer as a list of json objects with the following format:

[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]

Answer:
"""

INITIAL_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.

The information provided below consists of:
1) a number of answered sub-questions - these are very important(!) and definitely should be
considered to answer the question.
2) a number of documents that were also deemed relevant for the question.

If you don't know the answer or if the provided information is empty or insufficient, just say
"I don't know". Do not use your internal knowledge!

Again, only use the provided information and do not use your internal knowledge! It is a matter of life
and death that you do NOT use your internal knowledge, just the provided information!

Try to keep your answer concise.

And here is the question and the provided information:
\n
\nQuestion:\n {question}

\nAnswered Sub-questions:\n {answered_sub_questions}

\nContext:\n {context} \n\n
\n\n

Answer:"""

ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retrieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.

\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n

Please format your answer as a json object in the following format:

{{"retrieved_entities_relationships": {{
    "entities": [{{
        "entity_name": <assign a name for the entity>,
        "entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
    }}],
    "relationships": [{{
        "name": <assign a name for the relationship>,
        "type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
        "entities": [<related entity name 1>, <related entity name 2>]
    }}],
    "terms": [{{
        "term_name": <assign a name for the term>,
        "term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
        "similar_to": <list terms that are similar to this term>
    }}]
}}
}}
"""
73
backend/danswer/agent_search/primary_graph/states.py
Normal file
@@ -0,0 +1,73 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.context.search.models import InferenceSection


class QAState(TypedDict):
    # The 'main' state of the answer graph
    original_question: str
    graph_start_time: datetime
    # start time for the parallel initial sub-question thread
    sub_query_start_time: datetime
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
    rewritten_queries: RewrittenQueries
    sub_questions: list[dict]
    initial_sub_questions: list[dict]
    ranked_subquestion_ids: list[int]
    decomposed_sub_questions_dict: dict
    rejected_sub_questions: Annotated[list[str], operator.add]
    rejected_sub_questions_handled: bool
    sub_qas: Annotated[Sequence[dict], operator.add]
    initial_sub_qas: Annotated[Sequence[dict], operator.add]
    checked_sub_qas: Annotated[Sequence[dict], operator.add]
    base_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
    deduped_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
    reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
    retrieved_entities_relationships: dict
    questions_context: list[dict]
    qa_level: int
    top_chunks: list[InferenceSection]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    num_new_question_iterations: int
    core_answer_dynamic_context: str
    dynamic_context: str
    initial_base_answer: str
    base_answer: str
    deep_answer: str


class QAOuputState(TypedDict):
    # The 'main' output state of the answer graph. Removes all the intermediate states
    original_question: str
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
    sub_questions: list[dict]
    sub_qas: Annotated[Sequence[dict], operator.add]
    initial_sub_qas: Annotated[Sequence[dict], operator.add]
    checked_sub_qas: Annotated[Sequence[dict], operator.add]
    reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
    retrieved_entities_relationships: dict
    top_chunks: list[InferenceSection]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    base_answer: str
    deep_answer: str


class RetrieverState(TypedDict):
    # The state for the parallel Retrievers. They each need to see only one query
    rewritten_query: str
    graph_start_time: datetime


class VerifierState(TypedDict):
    # The state for the parallel verification step. Each node execution needs to see only one question/doc pair
    document: InferenceSection
    question: str
    graph_start_time: datetime
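A minimal input-state sketch for the primary graph; only the keys the entry nodes read are set (the question is hypothetical, and fast_llm is consumed by several nodes even though QAState does not declare it):

from datetime import datetime

from danswer.llm.factory import get_default_llms

_, fast_llm = get_default_llms()
initial_state = {
    "original_question": "What is our market share with company A?",  # hypothetical
    "graph_start_time": datetime.now(),
    "fast_llm": fast_llm,  # read by generate/decompose/rewrite
}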
22
backend/danswer/agent_search/run_graph.py
Normal file
@@ -0,0 +1,22 @@
from danswer.agent_search.primary_graph.graph_builder import build_core_graph
from danswer.llm.answering.answer import AnswerStream
from danswer.llm.interfaces import LLM
from danswer.tools.tool import Tool


def run_graph(
    query: str,
    llm: LLM,
    tools: list[Tool],
) -> AnswerStream:
    graph = build_core_graph()

    inputs = {
        "original_question": query,
        "messages": [],
        "tools": tools,
        "llm": llm,
    }
    compiled_graph = graph.compile()
    # stream the graph execution so callers receive incremental packets
    yield from compiled_graph.stream(input=inputs)
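A hedged usage sketch for run_graph; no tools are defined in this diff, so the tool list is left empty and the question is hypothetical:

from danswer.llm.factory import get_default_llms

llm, _ = get_default_llms()
for packet in run_graph("What is our revenue with company A?", llm, tools=[]):
    print(packet)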
16
backend/danswer/agent_search/shared_graph_utils/models.py
Normal file
@@ -0,0 +1,16 @@
from typing import Literal

from pydantic import BaseModel


# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
    rewritten_queries: list[str]


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


class SubQuestions(BaseModel):
    sub_questions: list[str]
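These models double as output contracts: their model_json_schema() is handed to the LLM as structured_response_format, and the response is validated back through them. For example:

from danswer.agent_search.shared_graph_utils.models import BinaryDecision

schema = BinaryDecision.model_json_schema()  # passed as structured_response_format
decision = BinaryDecision.model_validate({"decision": "yes"})
assert decision.decision == "yes"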
342
backend/danswer/agent_search/shared_graph_utils/prompts.py
Normal file
@@ -0,0 +1,342 @@
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
Please convert an initial user question into 2-3 more appropriate short and pointed search queries for retrieval from a
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
enabling the system to search more broadly.
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the query text): """


REWRITE_PROMPT_MULTI = """ \n
Please create a list of 2-3 sample documents that could answer an original question. Each document
should be about as long as the original question. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """

BASE_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the context provided below - and only the
provided context - to answer the question. If you don't know the answer or if the provided context is
empty, just say "I don't know". Do not use your internal knowledge!

Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!

Use three sentences maximum and keep the answer concise.
\nQuestion:\n {question} \nContext:\n {context} \n\n
\n\n
Answer:"""

BASE_CHECK_PROMPT = """ \n
Please check whether 1) the suggested answer seems to fully address the original question AND 2) the
original question requests a simple, factual answer, and there are no ambiguities, judgements,
aggregations, or any other complications that may require extra context. (I.e., if the question is
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)

Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{base_answer}
\n ------- \n
Please answer with yes or no:"""

VERIFIER_PROMPT = """ \n
Please check whether the document seems to be relevant for the answer of the original question. Please
only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the document text:
\n ------- \n
{document_content}
\n ------- \n
Please answer with yes or no:"""

INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose of this decomposition is to isolate individual entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n

Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Please formulate your answer as a list of subquestions:

Answer:
"""

REWRITE_PROMPT_SINGLE = """ \n
Please convert an initial user question into a more appropriate search query for retrieval from a
document store. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Formulate the query: """

MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
Use three sentences maximum and keep the answer concise.
Also pay particular attention to the sub-questions and their answers; they may enrich the answer.
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!

\nQuestion: {question}
\nContext: {combined_context} \n

Answer:"""

ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are given to you so that
you have an idea of what the available data looks like.

Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
considering:

1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context

The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.

Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcomings of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that they can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?")
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]

\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n

Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n

Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n

And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n

Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?

Please think through it step by step and then generate the list of json dictionaries with the following
format:

{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """

DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are given to you so that
you have an idea of what the available data looks like.

Your role is to generate 4-6 new sub-questions that would help to answer the initial question,
considering:

1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context

The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.

Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcomings of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that they can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?")
- For each sub-question, please also provide a search term that can be used to retrieve relevant
documents from a document store.
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n

Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n

Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n

And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n

Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?

Generate the list of json dictionaries with the following format:

{{"sub_questions": [{{"sub_question": <sub-question>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """

DECOMPOSE_PROMPT = """ \n
For an initial user question, please generate 5-10 individual sub-questions whose answers would help
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question.

In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
document store, expressed as entities, relationships and terms. You can also think about the types
mentioned in brackets

Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
and or resolve ambiguities
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that they can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?")
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]

\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n

Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Don't be too specific unless the original question is specific.
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """

#### Consolidations
COMBINED_CONTEXT = """-------
Below you will find useful information to answer the original question. First, you see a number of
sub-questions with their answers. This information should be considered to be more focused and
somewhat more specific to the original question as it tries to contextualize facts.
After that you will see the documents that were considered to be relevant to answer the original question.

Here are the sub-questions and their answers:
\n\n {deep_answer_context} \n\n
\n\n Here are the documents that were considered to be relevant to answer the original question:
\n\n {formated_docs} \n\n
----------------
"""

SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
Below you will find a question that we ultimately want to answer (the original question) and a list of
motivations in arbitrary order for generated sub-questions that are supposed to help us answer the
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
motivation and 2 is less relevant.)

Please rank the motivations in order of relevance for answering the original question. Also, try to
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
Ultimately, create a list with the motivation numbers where the number of the most relevant
motivation comes first.

Here is the original question:
\n\n {original_question} \n\n
\n\n Here is the list of sub-question motivations:
\n\n {sub_question_explanations} \n\n
----------------

Please think step by step and then generate the ranked list of motivations.

Please format your answer as a json object in the following format:
{{"reasonning": <explain your reasoning for the ranking>,
"ranked_motivations": <ranked list of motivation numbers>}}
"""
91
backend/danswer/agent_search/shared_graph_utils/utils.py
Normal file
@@ -0,0 +1,91 @@
import ast
import json
import re
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any

from danswer.context.search.models import InferenceSection


def normalize_whitespace(text: str) -> str:
    """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
    return re.sub(r"\s+", " ", text.strip())


# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
    return "\n\n".join(doc.combined_content for doc in docs)


def clean_and_parse_list_string(json_string: str) -> list[dict]:
    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())
    # Parse the cleaned string into a Python list of dicts
    return ast.literal_eval(cleaned_string)


def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())
    # Parse the cleaned string into a Python dictionary
    return json.loads(cleaned_string)


def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
    entities = entity_term_extraction_dict["entities"]
    terms = entity_term_extraction_dict["terms"]
    relationships = entity_term_extraction_dict["relationships"]

    entity_strs = ["\nEntities:\n"]
    for entity in entities:
        entity_str = f"{entity['entity_name']} ({entity['entity_type']})"
        entity_strs.append(entity_str)

    entity_str = "\n - ".join(entity_strs)

    relationship_strs = ["\n\nRelationships:\n"]
    for relationship in relationships:
        relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
        relationship_strs.append(relationship_str)

    relationship_str = "\n - ".join(relationship_strs)

    term_strs = ["\n\nTerms:\n"]
    for term in terms:
        term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
        term_strs.append(term_str)

    term_str = "\n - ".join(term_strs)

    # combine the bulleted sections built above
    return "\n".join([entity_str, relationship_str, term_str])


def _format_time_delta(time: timedelta) -> str:
    seconds_from_start = f"{time.seconds:03d}"
    microseconds_from_start = f"{time.microseconds:06d}"
    return f"{seconds_from_start}.{microseconds_from_start}"


def generate_log_message(
    message: str,
    node_start_time: datetime,
    graph_start_time: datetime | None = None,
) -> str:
    current_time = datetime.now()

    if graph_start_time is not None:
        graph_time_str = _format_time_delta(current_time - graph_start_time)
    else:
        graph_time_str = "N/A"

    node_time_str = _format_time_delta(current_time - node_start_time)

    return f"{graph_time_str} ({node_time_str} s): {message}"
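For example (timings illustrative), a log line for a node that started 0.75 s ago in a graph that started 1 s ago:

from datetime import datetime
from datetime import timedelta

from danswer.agent_search.shared_graph_utils.utils import generate_log_message

graph_start = datetime.now() - timedelta(seconds=1)
node_start = datetime.now() - timedelta(microseconds=750_000)
print(generate_log_message("core - generate", node_start, graph_start))
# e.g. "001.000000 (000.750000 s): core - generate"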
@@ -25,6 +25,9 @@ class ToolCallSummary(BaseModel__v1):
    tool_call_request: AIMessage
    tool_call_result: ToolMessage

    class Config:
        arbitrary_types_allowed = True


def tool_call_tokens(
    tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer
@@ -26,9 +26,14 @@ huggingface-hub==0.20.1
jira==3.5.1
jsonref==1.1.0
trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
langchain==0.3.7
langchain-core==0.3.20
langchain-openai==0.2.9
langchain-text-splitters==0.3.2
langchainhub==0.1.21
langgraph==0.2.53
langgraph-checkpoint==2.0.5
langgraph-sdk==0.1.36
litellm==1.53.1
lxml==5.3.0
lxml_html_clean==0.2.2