Compare commits

...

10 Commits

Author SHA1 Message Date
joachim-danswer
0d848fa9dd more experimentation changes 2024-12-13 12:57:58 -08:00
joachim-danswer
1dbb6f3d69 experimentation 2024-12-09 11:14:18 -08:00
joachim-danswer
91cf9a5472 Fit Score & more suitable rewrite 2024-12-08 09:20:03 -08:00
joachim-danswer
ebb0e56a30 initial fixes for core_qa_graph 2024-12-07 22:07:48 -08:00
hagen-danswer
091cb136c4 got core qa graph working 2024-12-07 12:25:54 -08:00
hagen-danswer
56052c5b4b imports 2024-12-07 06:09:57 -08:00
hagen-danswer
617726207b all 3 graphs r done 2024-12-07 06:06:22 -08:00
hagen-danswer
1be58e74b3 Finished primary graph 2024-12-06 11:01:03 -08:00
hagen-danswer
a693c991d7 Merge remote-tracking branch 'origin/agent-search-a' into initial-implementation 2024-12-04 15:42:58 -08:00
hagen-danswer
4b28686721 Added Initial Implementation of the Agent Search Graph 2024-12-02 07:16:08 -08:00
51 changed files with 2917 additions and 3 deletions

View File

@@ -0,0 +1,42 @@
from collections.abc import Hashable
from typing import Union
from langgraph.types import Send
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.primary_graph.states import VerifierState
def sub_continue_to_verifier(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
# Routes each de-douped retrieved doc to the verifier step - in parallel
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_verifier",
VerifierState(
document=doc,
#question=state["original_question"],
question=state["sub_question_str"],
graph_start_time=state["graph_start_time"],
),
)
for doc in state["sub_question_deduped_retrieval_docs"]
]
def sub_continue_to_retrieval(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
rewritten_queries = state["sub_question_search_queries"].rewritten_queries + [state["sub_question_str"]]
return [
Send(
"sub_custom_retrieve",
RetrieverState(
rewritten_query=query,
graph_start_time=state["graph_start_time"],
),
)
for query in rewritten_queries
]

View File

@@ -0,0 +1,132 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.core_qa_graph.edges import sub_continue_to_retrieval
from danswer.agent_search.core_qa_graph.edges import sub_continue_to_verifier
from danswer.agent_search.core_qa_graph.nodes.combine_retrieved_docs import (
sub_combine_retrieved_docs,
)
from danswer.agent_search.core_qa_graph.nodes.custom_retrieve import (
sub_custom_retrieve,
)
from danswer.agent_search.core_qa_graph.nodes.dummy import sub_dummy
from danswer.agent_search.core_qa_graph.nodes.final_format import (
sub_final_format,
)
from danswer.agent_search.core_qa_graph.nodes.generate import sub_generate
from danswer.agent_search.core_qa_graph.nodes.qa_check import sub_qa_check
from danswer.agent_search.core_qa_graph.nodes.rewrite import sub_rewrite
from danswer.agent_search.core_qa_graph.nodes.verifier import sub_verifier
from danswer.agent_search.core_qa_graph.states import BaseQAOutputState
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.core_qa_graph.states import CoreQAInputState
def build_core_qa_graph() -> StateGraph:
sub_answers_initial = StateGraph(
state_schema=BaseQAState,
output=BaseQAOutputState,
)
### Add nodes ###
sub_answers_initial.add_node(node="sub_dummy", action=sub_dummy)
sub_answers_initial.add_node(node="sub_rewrite", action=sub_rewrite)
sub_answers_initial.add_node(
node="sub_custom_retrieve",
action=sub_custom_retrieve,
)
sub_answers_initial.add_node(
node="sub_combine_retrieved_docs",
action=sub_combine_retrieved_docs,
)
sub_answers_initial.add_node(
node="sub_verifier",
action=sub_verifier,
)
sub_answers_initial.add_node(
node="sub_generate",
action=sub_generate,
)
sub_answers_initial.add_node(
node="sub_qa_check",
action=sub_qa_check,
)
sub_answers_initial.add_node(
node="sub_final_format",
action=sub_final_format,
)
### Add edges ###
sub_answers_initial.add_edge(START, "sub_dummy")
sub_answers_initial.add_edge("sub_dummy", "sub_rewrite")
sub_answers_initial.add_conditional_edges(
source="sub_rewrite",
path=sub_continue_to_retrieval,
)
sub_answers_initial.add_edge(
start_key="sub_custom_retrieve",
end_key="sub_combine_retrieved_docs",
)
sub_answers_initial.add_conditional_edges(
source="sub_combine_retrieved_docs",
path=sub_continue_to_verifier,
path_map=["sub_verifier"],
)
sub_answers_initial.add_edge(
start_key="sub_verifier",
end_key="sub_generate",
)
sub_answers_initial.add_edge(
start_key="sub_generate",
end_key="sub_qa_check",
)
sub_answers_initial.add_edge(
start_key="sub_qa_check",
end_key="sub_final_format",
)
sub_answers_initial.add_edge(
start_key="sub_final_format",
end_key=END,
)
# sub_answers_graph = sub_answers_initial.compile()
return sub_answers_initial
if __name__ == "__main__":
# q = "Whose music is kind of hard to easily enjoy?"
# q = "What is voice leading?"
# q = "What are the types of motions in music?"
# q = "What are key elements of music theory?"
# q = "How can I best understand music theory using voice leading?"
q = "What makes good music?"
# q = "types of motions in music"
# q = "What is the relationship between music and physics?"
# q = "Can you compare various grunge styles?"
# q = "Why is quantum gravity so hard?"
inputs = CoreQAInputState(
original_question=q,
sub_question_str=q,
)
sub_answers_graph = build_core_qa_graph()
compiled_sub_answers = sub_answers_graph.compile()
output = compiled_sub_answers.invoke(inputs)
print("\nOUTPUT:")
print(output.keys())
for key, value in output.items():
if key in [
"sub_question_answer",
"sub_question_str",
"sub_qas",
"initial_sub_qas",
"sub_question_answer",
]:
print(f"{key}: {value}")

View File

@@ -0,0 +1,36 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
def sub_combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]:
"""
Dedupe the retrieved docs.
"""
node_start_time = datetime.now()
sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
dedupe_docs: list[InferenceSection] = []
for base_retrieval_doc in sub_question_base_retrieval_docs:
if not any(
base_retrieval_doc.center_chunk.chunk_id == doc.center_chunk.chunk_id
for doc in dedupe_docs
):
dedupe_docs.append(base_retrieval_doc)
print(f"Number of deduped docs: {len(dedupe_docs)}")
return {
"sub_question_deduped_retrieval_docs": dedupe_docs,
"log_messages": generate_log_message(
message="sub - combine_retrieved_docs (dedupe)",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,66 @@
import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms
def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE SUB---")
node_start_time = datetime.datetime.now()
rewritten_query = state["rewritten_query"]
# Retrieval
# TODO: add the actual retrieval, probably from search_tool.run()
documents: list[InferenceSection] = []
llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
documents = SearchPipeline(
search_request=SearchRequest(
query=rewritten_query,
),
user=None,
llm=llm,
fast_llm=fast_llm,
db_session=db_session,
)
reranked_docs = documents.reranked_sections
# initial metric to measure fit TODO: implement metric properly
top_1_score = reranked_docs[0].center_chunk.score
top_5_score = sum([doc.center_chunk.score for doc in reranked_docs[:5]]) / 5
top_10_score = sum([doc.center_chunk.score for doc in reranked_docs[:10]]) / 10
fit_score = 1/3 * (top_1_score + top_5_score + top_10_score)
chunk_ids = {'query': rewritten_query,
'chunk_ids': [doc.center_chunk.chunk_id for doc in reranked_docs]}
return {
"sub_question_base_retrieval_docs": reranked_docs,
"sub_chunk_ids": [chunk_ids],
"log_messages": generate_log_message(
message=f"sub - custom_retrieve, fit_score: {fit_score}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,24 @@
import datetime
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_dummy(state: BaseQAState) -> dict[str, Any]:
"""
Dummy step
"""
print("---Sub Dummy---")
node_start_time = datetime.datetime.now()
return {
"graph_start_time": node_start_time,
"log_messages": generate_log_message(
message="sub - dummy",
node_start_time=node_start_time,
graph_start_time=node_start_time,
),
}

View File

@@ -0,0 +1,22 @@
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
def sub_final_format(state: BaseQAState) -> dict[str, Any]:
"""
Create the final output for the QA subgraph
"""
print("---BASE FINAL FORMAT---")
return {
"sub_qas": [
{
"sub_question": state["sub_question_str"],
"sub_answer": state["sub_question_answer"],
"sub_answer_check": state["sub_question_answer_check"],
}
],
"log_messages": state["log_messages"],
}

View File

@@ -0,0 +1,91 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_generate(state: BaseQAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---GENERATE---")
# Create sub-query results
verified_chunks = [chunk.center_chunk.chunk_id for chunk in state["sub_question_verified_retrieval_docs"]]
result_dict = {}
chunk_id_dicts = state["sub_chunk_ids"]
expanded_chunks = []
original_chunks = []
for chunk_id_dict in chunk_id_dicts:
sub_question = chunk_id_dict['query']
verified_sq_chunks = [chunk_id for chunk_id in chunk_id_dict['chunk_ids'] if chunk_id in verified_chunks]
if sub_question != state["original_question"]:
expanded_chunks += verified_sq_chunks
else:
result_dict['ORIGINAL'] = len(verified_sq_chunks)
original_chunks += verified_sq_chunks
result_dict[sub_question[:30]] = len(verified_sq_chunks)
expansion_chunks = set(expanded_chunks)
num_expansion_chunks = sum([1 for chunk_id in expansion_chunks if chunk_id in verified_chunks])
num_original_relevant_chunks = len(original_chunks)
num_missed_relevant_chunks = sum([1 for chunk_id in original_chunks if chunk_id not in expansion_chunks])
num_gained_relevant_chunks = sum([1 for chunk_id in expansion_chunks if chunk_id not in original_chunks])
result_dict['expansion_chunks'] = num_expansion_chunks
print(result_dict)
node_start_time = datetime.now()
question = state["sub_question_str"]
docs = state["sub_question_verified_retrieval_docs"]
print(f"Number of verified retrieval docs: {len(docs)}")
# Only take the top 10 docs.
# TODO: Make this dynamic or use config param?
top_10_docs = docs[-10:]
msg = [
HumanMessage(
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(top_10_docs))
)
]
# Grader
_, fast_llm = get_default_llms()
response = list(
fast_llm.stream(
prompt=msg,
# structured_response_format=None,
)
)
answer_str = merge_message_runs(response, chunk_separator="")[0].content
return {
"sub_question_answer": answer_str,
"log_messages": generate_log_message(
message="base - generate",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,51 @@
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_qa_check(state: BaseQAState) -> dict[str, Any]:
"""
Check if the sub-question answer is satisfactory.
Args:
state: The current SubQAState containing the sub-question and its answer
Returns:
dict containing the check result and log message
"""
node_start_time = datetime.datetime.now()
msg = [
HumanMessage(
content=BASE_CHECK_PROMPT.format(
question=state["sub_question_str"],
base_answer=state["sub_question_answer"],
)
)
]
_, fast_llm = get_default_llms()
response = list(
fast_llm.stream(
prompt=msg,
# structured_response_format=None,
)
)
response_str = merge_message_runs(response, chunk_separator="")[0].content
return {
"sub_question_answer_check": response_str,
"base_answer_messages": generate_log_message(
message="sub - qa_check",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,74 @@
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import (
REWRITE_PROMPT_MULTI_ORIGINAL,
)
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_rewrite(state: BaseQAState) -> dict[str, Any]:
"""
Transform the initial question into more suitable search queries.
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---SUB TRANSFORM QUERY---")
node_start_time = datetime.datetime.now()
# messages = state["base_answer_messages"]
question = state["sub_question_str"]
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
)
]
"""
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
"""
_, fast_llm = get_default_llms()
llm_response_list = list(
fast_llm.stream(
prompt=msg,
# structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
# structured_response_format=RewrittenQueries.model_json_schema(),
)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
print(f"llm_response: {llm_response}")
rewritten_queries = llm_response.split("--")
# rewritten_queries = [llm_response.split("\n")[0]]
print(f"rewritten_queries: {rewritten_queries}")
rewritten_queries = RewrittenQueries(rewritten_queries=rewritten_queries)
return {
"sub_question_search_queries": rewritten_queries,
"log_messages": generate_log_message(
message="sub - rewrite",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,64 @@
import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def sub_verifier(state: VerifierState) -> dict[str, Any]:
"""
Check whether the document is relevant for the original user question
Args:
state (VerifierState): The current state
Returns:
dict: ict: The updated state with the final decision
"""
# print("---VERIFY QUTPUT---")
node_start_time = datetime.datetime.now()
question = state["question"]
document_content = state["document"].combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
# Grader
llm, fast_llm = get_default_llms()
response = list(
llm.stream(
prompt=msg,
# structured_response_format=BinaryDecision.model_json_schema(),
)
)
response_string = merge_message_runs(response, chunk_separator="")[0].content
# Convert string response to proper dictionary format
decision_dict = {"decision": response_string.lower()}
formatted_response = BinaryDecision.model_validate(decision_dict)
print(f"Verification end time: {datetime.datetime.now()}")
return {
"sub_question_verified_retrieval_docs": [state["document"]]
if formatted_response.decision == "yes"
else [],
"log_messages": generate_log_message(
message=f"sub - verifier: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,90 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM
class SubQuestionRetrieverState(TypedDict):
# The state for the parallel Retrievers. They each need to see only one query
sub_question_rewritten_query: str
class SubQuestionVerifierState(TypedDict):
# The state for the parallel verification step. Each node execution need to see only one question/doc pair
sub_question_document: InferenceSection
sub_question: str
class CoreQAInputState(TypedDict):
sub_question_str: str
original_question: str
class BaseQAState(TypedDict):
# The 'core SubQuestion' state.
original_question: str
graph_start_time: datetime
# start time for parallel initial sub-questionn thread
sub_query_start_time: datetime
sub_question_rewritten_queries: list[str]
sub_question_str: str
sub_question_search_queries: RewrittenQueries
sub_question_nr: int
sub_chunk_ids: Annotated[Sequence[dict], operator.add]
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]
sub_qas: Annotated[Sequence[dict], operator.add]
# Answers sent back to core
initial_sub_qas: Annotated[Sequence[dict], operator.add]
primary_llm: LLM
fast_llm: LLM
class BaseQAOutputState(TypedDict):
# The 'SubQuestion' output state. Removes all the intermediate states
sub_question_rewritten_queries: list[str]
sub_question_str: str
sub_question_search_queries: list[str]
sub_question_nr: int
# Answers sent back to core
sub_qas: Annotated[Sequence[dict], operator.add]
# Answers sent back to core
initial_sub_qas: Annotated[Sequence[dict], operator.add]
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]

View File

@@ -0,0 +1,46 @@
from collections.abc import Hashable
from typing import Union
from langgraph.types import Send
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.primary_graph.states import VerifierState
def sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Hashable]]:
# Routes each de-douped retrieved doc to the verifier step - in parallel
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_verifier",
VerifierState(
document=doc,
question=state["sub_question"],
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
graph_start_time=state["graph_start_time"],
),
)
for doc in state["sub_question_base_retrieval_docs"]
]
def sub_continue_to_retrieval(
state: ResearchQAState,
) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_custom_retrieve",
RetrieverState(
rewritten_query=query,
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
graph_start_time=state["graph_start_time"],
),
)
for query in state["sub_question_rewritten_queries"]
]

View File

@@ -0,0 +1,93 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_retrieval
from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_verifier
from danswer.agent_search.deep_qa_graph.nodes.combine_retrieved_docs import (
sub_combine_retrieved_docs,
)
from danswer.agent_search.deep_qa_graph.nodes.custom_retrieve import sub_custom_retrieve
from danswer.agent_search.deep_qa_graph.nodes.dummy import sub_dummy
from danswer.agent_search.deep_qa_graph.nodes.final_format import sub_final_format
from danswer.agent_search.deep_qa_graph.nodes.generate import sub_generate
from danswer.agent_search.deep_qa_graph.nodes.qa_check import sub_qa_check
from danswer.agent_search.deep_qa_graph.nodes.verifier import sub_verifier
from danswer.agent_search.deep_qa_graph.states import ResearchQAOutputState
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
def build_deep_qa_graph() -> StateGraph:
# Define the nodes we will cycle between
sub_answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState)
### Add Nodes ###
# Dummy node for initial processing
sub_answers.add_node(node="sub_dummy", action=sub_dummy)
# The retrieval step
sub_answers.add_node(node="sub_custom_retrieve", action=sub_custom_retrieve)
# The dedupe step
sub_answers.add_node(
node="sub_combine_retrieved_docs", action=sub_combine_retrieved_docs
)
# Verifying retrieved information
sub_answers.add_node(node="sub_verifier", action=sub_verifier)
# Generating the response
sub_answers.add_node(node="sub_generate", action=sub_generate)
# Checking the quality of the answer
sub_answers.add_node(node="sub_qa_check", action=sub_qa_check)
# Final formatting of the response
sub_answers.add_node(node="sub_final_format", action=sub_final_format)
### Add Edges ###
# Generate multiple sub-questions
sub_answers.add_edge(start_key=START, end_key="sub_rewrite")
# For each sub-question, perform a retrieval in parallel
sub_answers.add_conditional_edges(
source="sub_rewrite",
path=sub_continue_to_retrieval,
path_map=["sub_custom_retrieve"],
)
# Combine the retrieved docs for each sub-question from the parallel retrievals
sub_answers.add_edge(
start_key="sub_custom_retrieve", end_key="sub_combine_retrieved_docs"
)
# Go over all of the combined retrieved docs and verify them against the original question
sub_answers.add_conditional_edges(
source="sub_combine_retrieved_docs",
path=sub_continue_to_verifier,
path_map=["sub_verifier"],
)
# Generate an answer for each verified retrieved doc
sub_answers.add_edge(start_key="sub_verifier", end_key="sub_generate")
# Check the quality of the answer
sub_answers.add_edge(start_key="sub_generate", end_key="sub_qa_check")
sub_answers.add_edge(start_key="sub_qa_check", end_key="sub_final_format")
sub_answers.add_edge(start_key="sub_final_format", end_key=END)
return sub_answers
if __name__ == "__main__":
# TODO: add the actual question
inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"}
sub_answers_graph = build_deep_qa_graph()
compiled_sub_answers = sub_answers_graph.compile()
output = compiled_sub_answers.invoke(inputs)
print("\nOUTPUT:")
print(output)

View File

@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]:
"""
Dedupe the retrieved docs.
"""
node_start_time = datetime.now()
sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
dedupe_docs = []
for base_retrieval_doc in sub_question_base_retrieval_docs:
if base_retrieval_doc not in dedupe_docs:
dedupe_docs.append(base_retrieval_doc)
print(f"Number of deduped docs: {len(dedupe_docs)}")
return {
"sub_question_deduped_retrieval_docs": dedupe_docs,
"log_messages": generate_log_message(
message="sub - combine_retrieved_docs (dedupe)",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,33 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
"""
Retrieve documents
Args:
state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE SUB---")
node_start_time = datetime.now()
# Retrieval
# TODO: add the actual retrieval, probably from search_tool.run()
documents: list[InferenceSection] = []
return {
"sub_question_base_retrieval_docs": documents,
"log_messages": generate_log_message(
message="sub - custom_retrieve",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,21 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_dummy(state: BaseQAState) -> dict[str, Any]:
"""
Dummy step
"""
print("---Sub Dummy---")
return {
"log_messages": generate_log_message(
message="sub - dummy",
node_start_time=datetime.now(),
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,31 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_final_format(state: ResearchQAState) -> dict[str, Any]:
"""
Create the final output for the QA subgraph
"""
print("---SUB FINAL FORMAT---")
node_start_time = datetime.now()
return {
# TODO: Type this
"sub_qas": [
{
"sub_question": state["sub_question"],
"sub_answer": state["sub_question_answer"],
"sub_question_nr": state["sub_question_nr"],
"sub_answer_check": state["sub_question_answer_check"],
}
],
"log_messages": generate_log_message(
message="sub - final format",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,56 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_generate(state: ResearchQAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---SUB GENERATE---")
node_start_time = datetime.now()
question = state["sub_question"]
docs = state["sub_question_verified_retrieval_docs"]
print(f"Number of verified retrieval docs for sub-question: {len(docs)}")
msg = [
HumanMessage(
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
)
]
# Grader
if len(docs) > 0:
model = state["fast_llm"]
response = list(
model.stream(
prompt=msg,
)
)
response_str = merge_message_runs(response, chunk_separator="")[0].content
else:
response_str = ""
return {
"sub_question_answer": response_str,
"log_messages": generate_log_message(
message="sub - generate",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,57 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.deep_qa_graph.prompts import SUB_CHECK_PROMPT
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_qa_check(state: ResearchQAState) -> dict[str, Any]:
"""
Check whether the final output satisfies the original user question
Args:
state (messages): The current state
Returns:
dict: The updated state with the final decision
"""
print("---CHECK SUB QUTPUT---")
node_start_time = datetime.now()
sub_answer = state["sub_question_answer"]
sub_question = state["sub_question"]
msg = [
HumanMessage(
content=SUB_CHECK_PROMPT.format(
sub_question=sub_question, sub_answer=sub_answer
)
)
]
# Grader
model = state["fast_llm"]
response = list(
model.stream(
prompt=msg,
structured_response_format=BinaryDecision.model_json_schema(),
)
)
raw_response = json.loads(response[0].pretty_repr())
formatted_response = BinaryDecision.model_validate(raw_response)
return {
"sub_question_answer_check": formatted_response.decision,
"log_messages": generate_log_message(
message=f"sub - qa check: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,64 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.interfaces import LLM
def sub_rewrite(state: ResearchQAState) -> dict[str, Any]:
"""
Transform the initial question into more suitable search queries.
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---SUB TRANSFORM QUERY---")
node_start_time = datetime.now()
question = state["sub_question"]
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
fast_llm: LLM = state["fast_llm"]
llm_response = list(
fast_llm.stream(
prompt=msg,
structured_response_format=RewrittenQueries.model_json_schema(),
)
)
# Get the rewritten queries in a defined format
rewritten_queries: RewrittenQueries = json.loads(llm_response[0].pretty_repr())
print(f"rewritten_queries: {rewritten_queries}")
rewritten_queries = RewrittenQueries(
rewritten_queries=[
"music hard to listen to",
"Music that is not fun or pleasant",
]
)
print(f"hardcoded rewritten_queries: {rewritten_queries}")
return {
"sub_question_rewritten_queries": rewritten_queries,
"log_messages": generate_log_message(
message="sub - rewrite",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,59 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_verifier(state: VerifierState) -> dict[str, Any]:
"""
Check whether the document is relevant for the original user question
Args:
state (VerifierState): The current state
Returns:
dict: ict: The updated state with the final decision
"""
print("---SUB VERIFY QUTPUT---")
node_start_time = datetime.now()
question = state["question"]
document_content = state["document"].combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
# Grader
model = state["fast_llm"]
response = list(
model.stream(
prompt=msg,
structured_response_format=BinaryDecision.model_json_schema(),
)
)
raw_response = json.loads(response[0].pretty_repr())
formatted_response = BinaryDecision.model_validate(raw_response)
return {
"deduped_retrieval_docs": [state["document"]]
if formatted_response.decision == "yes"
else [],
"log_messages": generate_log_message(
message=f"core - verifier: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,13 @@
SUB_CHECK_PROMPT = """ \n
Please check whether the suggested answer seems to address the original question.
Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{base_answer}
\n ------- \n
Please answer with yes or no:"""

View File

@@ -0,0 +1,64 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM
class ResearchQAState(TypedDict):
# The 'core SubQuestion' state.
original_question: str
graph_start_time: datetime
sub_question_rewritten_queries: list[str]
sub_question: str
sub_question_nr: int
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]
sub_qas: Annotated[Sequence[dict], operator.add]
primary_llm: LLM
fast_llm: LLM
class ResearchQAOutputState(TypedDict):
# The 'SubQuestion' output state. Removes all the intermediate states
sub_question_rewritten_queries: list[str]
sub_question: str
sub_question_nr: int
# Answers sent back to core
sub_qas: Annotated[Sequence[dict], operator.add]
sub_question_base_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_deduped_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_verified_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_reranked_retrieval_docs: Annotated[
Sequence[InferenceSection], operator.add
]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
sub_question_answer: str
sub_question_answer_check: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]

View File

@@ -0,0 +1,75 @@
from collections.abc import Hashable
from typing import Union
from langchain_core.messages import HumanMessage
from langgraph.types import Send
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
def continue_to_initial_sub_questions(
state: QAState,
) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_answers_graph_initial",
BaseQAState(
sub_question_str=initial_sub_question["sub_question_str"],
sub_question_search_queries=initial_sub_question[
"sub_question_search_queries"
],
sub_question_nr=initial_sub_question["sub_question_nr"],
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
graph_start_time=state["graph_start_time"],
),
)
for initial_sub_question in state["initial_sub_questions"]
]
def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
# Routes re-written queries to the (parallel) retrieval steps
# Notice the 'Send()' API that takes care of the parallelization
return [
Send(
"sub_answers_graph",
ResearchQAState(
sub_question=sub_question["sub_question_str"],
sub_question_nr=sub_question["sub_question_nr"],
graph_start_time=state["graph_start_time"],
primary_llm=state["primary_llm"],
fast_llm=state["fast_llm"],
),
)
for sub_question in state["sub_questions"]
]
def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
print("---GO TO DEEP ANSWER OR END---")
base_answer = state["base_answer"]
question = state["original_question"]
BASE_CHECK_MESSAGE = [
HumanMessage(
content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
)
]
model = state["fast_llm"]
response = model.invoke(BASE_CHECK_MESSAGE)
print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")
if response.pretty_repr() == "no":
return "decompose"
else:
return "end"

View File

@@ -0,0 +1,171 @@
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.core_qa_graph.graph_builder import build_core_qa_graph
from danswer.agent_search.deep_qa_graph.graph_builder import build_deep_qa_graph
from danswer.agent_search.primary_graph.edges import continue_to_answer_sub_questions
from danswer.agent_search.primary_graph.edges import continue_to_deep_answer
from danswer.agent_search.primary_graph.edges import continue_to_initial_sub_questions
from danswer.agent_search.primary_graph.nodes.base_wait import base_wait
from danswer.agent_search.primary_graph.nodes.combine_retrieved_docs import (
combine_retrieved_docs,
)
from danswer.agent_search.primary_graph.nodes.custom_retrieve import custom_retrieve
from danswer.agent_search.primary_graph.nodes.decompose import decompose
from danswer.agent_search.primary_graph.nodes.deep_answer_generation import (
deep_answer_generation,
)
from danswer.agent_search.primary_graph.nodes.dummy_start import dummy_start
from danswer.agent_search.primary_graph.nodes.entity_term_extraction import (
entity_term_extraction,
)
from danswer.agent_search.primary_graph.nodes.final_stuff import final_stuff
from danswer.agent_search.primary_graph.nodes.generate_initial import generate_initial
from danswer.agent_search.primary_graph.nodes.main_decomp_base import main_decomp_base
from danswer.agent_search.primary_graph.nodes.rewrite import rewrite
from danswer.agent_search.primary_graph.nodes.sub_qa_level_aggregator import (
sub_qa_level_aggregator,
)
from danswer.agent_search.primary_graph.nodes.sub_qa_manager import sub_qa_manager
from danswer.agent_search.primary_graph.nodes.verifier import verifier
from danswer.agent_search.primary_graph.states import QAState
def build_core_graph() -> StateGraph:
# Define the nodes we will cycle between
core_answer_graph = StateGraph(state_schema=QAState)
### Add Nodes ###
core_answer_graph.add_node(node="dummy_start",
action=dummy_start)
# Re-writing the question
core_answer_graph.add_node(node="rewrite",
action=rewrite)
# The retrieval step
core_answer_graph.add_node(node="custom_retrieve",
action=custom_retrieve)
# Combine and dedupe retrieved docs.
core_answer_graph.add_node(
node="combine_retrieved_docs",
action=combine_retrieved_docs
)
# Extract entities, terms and relationships
core_answer_graph.add_node(
node="entity_term_extraction",
action=entity_term_extraction
)
# Verifying that a retrieved doc is relevant
core_answer_graph.add_node(node="verifier",
action=verifier)
# Initial question decomposition
core_answer_graph.add_node(node="main_decomp_base",
action=main_decomp_base)
# Build the base QA sub-graph and compile it
compiled_core_qa_graph = build_core_qa_graph().compile()
# Add the compiled base QA sub-graph as a node to the core graph
core_answer_graph.add_node(
node="sub_answers_graph_initial",
action=compiled_core_qa_graph
)
# Checking whether the initial answer is in the ballpark
core_answer_graph.add_node(node="base_wait",
action=base_wait)
# Decompose the question into sub-questions
core_answer_graph.add_node(node="decompose",
action=decompose)
# Manage the sub-questions
core_answer_graph.add_node(node="sub_qa_manager",
action=sub_qa_manager)
# Build the research QA sub-graph and compile it
compiled_deep_qa_graph = build_deep_qa_graph().compile()
# Add the compiled research QA sub-graph as a node to the core graph
core_answer_graph.add_node(node="sub_answers_graph",
action=compiled_deep_qa_graph)
# Aggregate the sub-questions
core_answer_graph.add_node(
node="sub_qa_level_aggregator",
action=sub_qa_level_aggregator
)
# aggregate sub questions and answers
core_answer_graph.add_node(
node="deep_answer_generation",
action=deep_answer_generation
)
# A final clean-up step
core_answer_graph.add_node(node="final_stuff",
action=final_stuff)
# Generating a response after we know the documents are relevant
core_answer_graph.add_node(node="generate_initial",
action=generate_initial)
### Add Edges ###
# start the initial sub-question decomposition
core_answer_graph.add_edge(start_key=START,
end_key="main_decomp_base")
core_answer_graph.add_conditional_edges(
source="main_decomp_base",
path=continue_to_initial_sub_questions,
)
# use the retrieved information to generate the answer
core_answer_graph.add_edge(
start_key=["verifier", "sub_answers_graph_initial"],
end_key="generate_initial"
)
core_answer_graph.add_edge(start_key="generate_initial",
end_key="base_wait")
core_answer_graph.add_conditional_edges(
source="base_wait",
path=continue_to_deep_answer,
path_map={"decompose": "entity_term_extraction", "end": "final_stuff"},
)
core_answer_graph.add_edge(start_key="entity_term_extraction", end_key="decompose")
core_answer_graph.add_edge(start_key="decompose",
end_key="sub_qa_manager")
core_answer_graph.add_conditional_edges(
source="sub_qa_manager",
path=continue_to_answer_sub_questions,
)
core_answer_graph.add_edge(
start_key="sub_answers_graph",
end_key="sub_qa_level_aggregator"
)
core_answer_graph.add_edge(
start_key="sub_qa_level_aggregator",
end_key="deep_answer_generation"
)
core_answer_graph.add_edge(
start_key="deep_answer_generation",
end_key="final_stuff"
)
core_answer_graph.add_edge(start_key="final_stuff",
end_key=END)
core_answer_graph.compile()
return core_answer_graph

View File

@@ -0,0 +1,27 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def base_wait(state: QAState) -> dict[str, Any]:
"""
Ensures that all required steps are completed before proceeding to the next step
Args:
state (messages): The current state
Returns:
dict: {} (no operation, just logging)
"""
print("---Base Wait ---")
node_start_time = datetime.now()
return {
"log_messages": generate_log_message(
message="core - base_wait",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,36 @@
from collections.abc import Sequence
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
def combine_retrieved_docs(state: QAState) -> dict[str, Any]:
"""
Dedupe the retrieved docs.
"""
node_start_time = datetime.now()
base_retrieval_docs: Sequence[InferenceSection] = state["base_retrieval_docs"]
print(f"Number of docs from steps: {len(base_retrieval_docs)}")
dedupe_docs: list[InferenceSection] = []
for base_retrieval_doc in base_retrieval_docs:
if not any(
base_retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id
for doc in dedupe_docs
):
dedupe_docs.append(base_retrieval_doc)
print(f"Number of deduped docs: {len(dedupe_docs)}")
return {
"deduped_retrieval_docs": dedupe_docs,
"log_messages": generate_log_message(
message="core - combine_retrieved_docs (dedupe)",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,52 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms
def custom_retrieve(state: RetrieverState) -> dict[str, Any]:
"""
Retrieve documents
Args:
retriever_state (dict): The current graph state
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print("---RETRIEVE---")
node_start_time = datetime.now()
query = state["rewritten_query"]
# Retrieval
# TODO: add the actual retrieval, probably from search_tool.run()
llm, fast_llm = get_default_llms()
with get_session_context_manager() as db_session:
top_sections = SearchPipeline(
search_request=SearchRequest(
query=query,
),
user=None,
llm=llm,
fast_llm=fast_llm,
db_session=db_session,
).reranked_sections
print(len(top_sections))
documents: list[InferenceSection] = []
return {
"base_retrieval_docs": documents,
"log_messages": generate_log_message(
message="core - custom_retrieve",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,78 @@
import json
import re
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_entity_term_extraction
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def decompose(state: QAState) -> dict[str, Any]:
""" """
node_start_time = datetime.now()
question = state["original_question"]
base_answer = state["base_answer"]
# get the entity term extraction dict and properly format it
entity_term_extraction_dict = state["retrieved_entities_relationships"][
"retrieved_entities_relationships"
]
entity_term_extraction_str = format_entity_term_extraction(
entity_term_extraction_dict
)
initial_question_answers = state["initial_sub_qas"]
addressed_question_list = [
x["sub_question"]
for x in initial_question_answers
if x["sub_answer_check"] == "yes"
]
failed_question_list = [
x["sub_question"]
for x in initial_question_answers
if x["sub_answer_check"] == "no"
]
msg = [
HumanMessage(
content=DEEP_DECOMPOSE_PROMPT.format(
question=question,
entity_term_extraction_str=entity_term_extraction_str,
base_answer=base_answer,
answered_sub_questions="\n - ".join(addressed_question_list),
failed_sub_questions="\n - ".join(failed_question_list),
),
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr())
parsed_response = json.loads(cleaned_response)
sub_questions_dict = {}
for sub_question_nr, sub_question_dict in enumerate(
parsed_response["sub_questions"]
):
sub_question_dict["answered"] = False
sub_question_dict["verified"] = False
sub_questions_dict[sub_question_nr] = sub_question_dict
return {
"decomposed_sub_questions_dict": sub_questions_dict,
"log_messages": generate_log_message(
message="deep - decompose",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,61 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace
# aggregate sub questions and answers
def deep_answer_generation(state: QAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---DEEP GENERATE---")
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
deep_answer_context = state["core_answer_dynamic_context"]
print(f"Number of verified retrieval docs - deep: {len(docs)}")
combined_context = normalize_whitespace(
COMBINED_CONTEXT.format(
deep_answer_context=deep_answer_context, formated_docs=format_docs(docs)
)
)
msg = [
HumanMessage(
content=MODIFIED_RAG_PROMPT.format(
question=question, combined_context=combined_context
)
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
return {
"deep_answer": response.content,
"log_messages": generate_log_message(
message="deep - deep answer generation",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,11 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
def dummy_start(state: QAState) -> dict[str, Any]:
"""
Dummy node to set the start time
"""
return {"start_time": datetime.now()}

View File

@@ -0,0 +1,51 @@
import json
import re
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.primary_graph.prompts import ENTITY_TERM_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms
def entity_term_extraction(state: QAState) -> dict[str, Any]:
"""Extract entities and terms from the question and context"""
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
doc_context = format_docs(docs)
msg = [
HumanMessage(
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
)
]
_, fast_llm = get_default_llms()
# Grader
llm_response_list = list(
fast_llm.stream(
prompt=msg,
# structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
# structured_response_format=RewrittenQueries.model_json_schema(),
)
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
parsed_response = json.loads(cleaned_response)
return {
"retrieved_entities_relationships": parsed_response,
"log_messages": generate_log_message(
message="deep - entity term extraction",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,85 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def final_stuff(state: QAState) -> dict[str, Any]:
"""
Invokes the agent model to generate a response based on the current state. Given
the question, it will decide to retrieve using the retriever tool, or simply end.
Args:
state (messages): The current state
Returns:
dict: The updated state with the agent response appended to messages
"""
print("---FINAL---")
node_start_time = datetime.now()
messages = state["log_messages"]
time_ordered_messages = [x.pretty_repr() for x in messages]
time_ordered_messages.sort()
print("Message Log:")
print("\n".join(time_ordered_messages))
initial_sub_qas = state["initial_sub_qas"]
initial_sub_qa_list = []
for initial_sub_qa in initial_sub_qas:
if initial_sub_qa["sub_answer_check"] == "yes":
initial_sub_qa_list.append(
f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----'
)
initial_sub_qa_context = "\n".join(initial_sub_qa_list)
log_message = generate_log_message(
message="all - final_stuff",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
)
print(log_message)
print("--------------------------------")
base_answer = state["base_answer"]
print(f"Final Base Answer:\n{base_answer}")
print("--------------------------------")
print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
print("--------------------------------")
if not state.get("deep_answer"):
print("No Deep Answer was required")
return {
"log_messages": log_message,
}
deep_answer = state["deep_answer"]
sub_qas = state["sub_qas"]
sub_qa_list = []
for sub_qa in sub_qas:
if sub_qa["sub_answer_check"] == "yes":
sub_qa_list.append(
f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----'
)
sub_qa_context = "\n".join(sub_qa_list)
print(f"Final Base Answer:\n{base_answer}")
print("--------------------------------")
print(f"Final Deep Answer:\n{deep_answer}")
print("--------------------------------")
print("Sub Questions and Answers:")
print(sub_qa_context)
return {
"log_messages": generate_log_message(
message="all - final_stuff",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,52 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def generate(state: QAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---GENERATE---")
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
print(f"Number of verified retrieval docs: {len(docs)}")
msg = [
HumanMessage(
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
)
]
# Grader
llm = state["fast_llm"]
response = list(
llm.stream(
prompt=msg,
structured_response_format=None,
)
)
return {
"base_answer": response[0].pretty_repr(),
"log_messages": generate_log_message(
message="core - generate",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,72 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def generate_initial(state: QAState) -> dict[str, Any]:
"""
Generate answer
Args:
state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---GENERATE INITIAL---")
node_start_time = datetime.now()
question = state["original_question"]
docs = state["deduped_retrieval_docs"]
print(f"Number of verified retrieval docs - base: {len(docs)}")
sub_question_answers = state["initial_sub_qas"]
sub_question_answers_list = []
_SUB_QUESTION_ANSWER_TEMPLATE = """
Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
"""
for sub_question_answer_dict in sub_question_answers:
if (
sub_question_answer_dict["sub_answer_check"] == "yes"
and len(sub_question_answer_dict["sub_answer"]) > 0
and sub_question_answer_dict["sub_answer"] != "I don't know"
):
sub_question_answers_list.append(
_SUB_QUESTION_ANSWER_TEMPLATE.format(
sub_question=sub_question_answer_dict["sub_question"],
sub_answer=sub_question_answer_dict["sub_answer"],
)
)
sub_question_answer_str = "\n\n------\n\n".join(sub_question_answers_list)
msg = [
HumanMessage(
content=INITIAL_RAG_PROMPT.format(
question=question,
context=format_docs(docs),
answered_sub_questions=sub_question_answer_str,
)
)
]
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
return {
"base_answer": response.pretty_repr(),
"log_messages": generate_log_message(
message="core - generate initial",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,64 @@
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.prompts import INITIAL_DECOMPOSITION_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def main_decomp_base(state: QAState) -> dict[str, Any]:
"""
Perform an initial question decomposition, incl. one search term
Args:
state (messages): The current state
Returns:
dict: The updated state with initial decomposition
"""
print("---INITIAL DECOMP---")
node_start_time = datetime.now()
question = state["original_question"]
msg = [
HumanMessage(
content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
)
]
# Get the rewritten queries in a defined format
model = state["fast_llm"]
response = model.invoke(msg)
content = response.pretty_repr()
list_of_subquestions = clean_and_parse_list_string(content)
decomp_list = []
for sub_question_nr, sub_question in enumerate(list_of_subquestions):
sub_question_str = sub_question["sub_question"].strip()
# temporarily
sub_question_search_queries = [sub_question["search_term"]]
decomp_list.append(
{
"sub_question_str": sub_question_str,
"sub_question_search_queries": sub_question_search_queries,
"sub_question_nr": sub_question_nr,
}
)
return {
"initial_sub_questions": decomp_list,
"sub_query_start_time": node_start_time,
"log_messages": generate_log_message(
message="core - initial decomp",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,55 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def rewrite(state: QAState) -> dict[str, Any]:
"""
Transform the initial question into more suitable search queries.
Args:
qa_state (messages): The current state
Returns:
dict: The updated state with re-phrased question
"""
print("---STARTING GRAPH---")
graph_start_time = datetime.now()
print("---TRANSFORM QUERY---")
node_start_time = datetime.now()
question = state["original_question"]
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
# Get the rewritten queries in a defined format
fast_llm = state["fast_llm"]
llm_response = list(
fast_llm.stream(
prompt=msg,
structured_response_format=RewrittenQueries.model_json_schema(),
)
)
formatted_response: RewrittenQueries = json.loads(llm_response[0].pretty_repr())
return {
"rewritten_queries": formatted_response.rewritten_queries,
"log_messages": generate_log_message(
message="core - rewrite",
node_start_time=node_start_time,
graph_start_time=graph_start_time,
),
}

View File

@@ -0,0 +1,39 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
# aggregate sub questions and answers
def sub_qa_level_aggregator(state: QAState) -> dict[str, Any]:
sub_qas = state["sub_qas"]
node_start_time = datetime.now()
dynamic_context_list = [
"Below you will find useful information to answer the original question:"
]
checked_sub_qas = []
for core_answer_sub_qa in sub_qas:
question = core_answer_sub_qa["sub_question"]
answer = core_answer_sub_qa["sub_answer"]
verified = core_answer_sub_qa["sub_answer_check"]
if verified == "yes":
dynamic_context_list.append(
f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n"
)
checked_sub_qas.append({"sub_question": question, "sub_answer": answer})
dynamic_context = "\n".join(dynamic_context_list)
return {
"core_answer_dynamic_context": dynamic_context,
"checked_sub_qas": checked_sub_qas,
"log_messages": generate_log_message(
message="deep - sub qa level aggregator",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,28 @@
from datetime import datetime
from typing import Any
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def sub_qa_manager(state: QAState) -> dict[str, Any]:
""" """
node_start_time = datetime.now()
sub_questions_dict = state["decomposed_sub_questions_dict"]
sub_questions = {}
for sub_question_nr, sub_question_dict in sub_questions_dict.items():
sub_questions[sub_question_nr] = sub_question_dict["sub_question"]
return {
"sub_questions": sub_questions,
"num_new_question_iterations": 0,
"log_messages": generate_log_message(
message="deep - sub qa manager",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,59 @@
import json
from datetime import datetime
from typing import Any
from langchain_core.messages import HumanMessage
from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
def verifier(state: VerifierState) -> dict[str, Any]:
"""
Check whether the document is relevant for the original user question
Args:
state (VerifierState): The current state
Returns:
dict: ict: The updated state with the final decision
"""
print("---VERIFY QUTPUT---")
node_start_time = datetime.now()
question = state["question"]
document_content = state["document"].combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=question, document_content=document_content
)
)
]
# Grader
llm = state["fast_llm"]
response = list(
llm.stream(
prompt=msg,
structured_response_format=BinaryDecision.model_json_schema(),
)
)
raw_response = json.loads(response[0].pretty_repr())
formatted_response = BinaryDecision.model_validate(raw_response)
return {
"deduped_retrieval_docs": [state["document"]]
if formatted_response.decision == "yes"
else [],
"log_messages": generate_log_message(
message=f"core - verifier: {formatted_response.decision}",
node_start_time=node_start_time,
graph_start_time=state["graph_start_time"],
),
}

View File

@@ -0,0 +1,86 @@
INITIAL_DECOMPOSITION_PROMPT = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
For each sub-question, please also create one search term that can be used to retrieve relevant
documents from a document store.
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of json objects with the following format:
[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
Answer:
"""
INITIAL_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.
The information provided below consists of:
1) a number of answered sub-questions - these are very important(!) and definitely should be
considered to answer the question.
2) a number of documents that were also deemed relevant for the question.
If you don't know the answer or if the provided information is empty or insufficient, just say
"I don't know". Do not use your internal knowledge!
Again, only use the provided informationand do not use your internal knowledge! It is a matter of life
and death that you do NOT use your internal knowledge, just the provided information!
Try to keep your answer concise.
And here is the question and the provided information:
\n
\nQuestion:\n {question}
\nAnswered Sub-questions:\n {answered_sub_questions}
\nContext:\n {context} \n\n
\n\n
Answer:"""
ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n
Please format your answer as a json object in the following format:
{{"retrieved_entities_relationships": {{
"entities": [{{
"entity_name": <assign a name for the entity>,
"entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
}}],
"relationships": [{{
"name": <assign a name for the relationship>,
"type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
"entities": [<related entity name 1>, <related entity name 2>]
}}],
"terms": [{{
"term_name": <assign a name for the term>,
"term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
"similar_to": <list terms that are similar to this term>
}}]
}}
}}
"""

View File

@@ -0,0 +1,73 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict
from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.context.search.models import InferenceSection
class QAState(TypedDict):
# The 'main' state of the answer graph
original_question: str
graph_start_time: datetime
# start time for parallel initial sub-questionn thread
sub_query_start_time: datetime
log_messages: Annotated[Sequence[BaseMessage], add_messages]
rewritten_queries: RewrittenQueries
sub_questions: list[dict]
initial_sub_questions: list[dict]
ranked_subquestion_ids: list[int]
decomposed_sub_questions_dict: dict
rejected_sub_questions: Annotated[list[str], operator.add]
rejected_sub_questions_handled: bool
sub_qas: Annotated[Sequence[dict], operator.add]
initial_sub_qas: Annotated[Sequence[dict], operator.add]
checked_sub_qas: Annotated[Sequence[dict], operator.add]
base_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
deduped_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
retrieved_entities_relationships: dict
questions_context: list[dict]
qa_level: int
top_chunks: list[InferenceSection]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
num_new_question_iterations: int
core_answer_dynamic_context: str
dynamic_context: str
initial_base_answer: str
base_answer: str
deep_answer: str
class QAOuputState(TypedDict):
# The 'main' output state of the answer graph. Removes all the intermediate states
original_question: str
log_messages: Annotated[Sequence[BaseMessage], add_messages]
sub_questions: list[dict]
sub_qas: Annotated[Sequence[dict], operator.add]
initial_sub_qas: Annotated[Sequence[dict], operator.add]
checked_sub_qas: Annotated[Sequence[dict], operator.add]
reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
retrieved_entities_relationships: dict
top_chunks: list[InferenceSection]
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
base_answer: str
deep_answer: str
class RetrieverState(TypedDict):
# The state for the parallel Retrievers. They each need to see only one query
rewritten_query: str
graph_start_time: datetime
class VerifierState(TypedDict):
# The state for the parallel verification step. Each node execution need to see only one question/doc pair
document: InferenceSection
question: str
graph_start_time: datetime

View File

@@ -0,0 +1,22 @@
from danswer.agent_search.primary_graph.graph_builder import build_core_graph
from danswer.llm.answering.answer import AnswerStream
from danswer.llm.interfaces import LLM
from danswer.tools.tool import Tool
def run_graph(
query: str,
llm: LLM,
tools: list[Tool],
) -> AnswerStream:
graph = build_core_graph()
inputs = {
"original_question": query,
"messages": [],
"tools": tools,
"llm": llm,
}
compiled_graph = graph.compile()
output = compiled_graph.invoke(input=inputs)
yield from output

View File

@@ -0,0 +1,16 @@
from typing import Literal
from pydantic import BaseModel
# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
rewritten_queries: list[str]
class BinaryDecision(BaseModel):
decision: Literal["yes", "no"]
class SubQuestions(BaseModel):
sub_questions: list[str]

View File

@@ -0,0 +1,342 @@
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
enabling the system to search more broadly.
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the querytext): """
REWRITE_PROMPT_MULTI = """ \n
Please create a list of 2-3 sample documents that could answer an original question. Each document
should be about as long as the original question. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
BASE_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the context provided below - and only the
provided context - to answer the question. If you don't know the answer or if the provided context is
empty, just say "I don't know". Do not use your internal knowledge!
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
Use three sentences maximum and keep the answer concise.
answer concise.\nQuestion:\n {question} \nContext:\n {context} \n\n
\n\n
Answer:"""
BASE_CHECK_PROMPT = """ \n
Please check whether 1) the suggested answer seems to fully address the original question AND 2)the
original question requests a simple, factual answer, and there are no ambiguities, judgements,
aggregations, or any other complications that may require extra context. (I.e., if the question is
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{base_answer}
\n ------- \n
Please answer with yes or no:"""
VERIFIER_PROMPT = """ \n
Please check whether the document seems to be relevant for the answer of the original question. Please
only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the document text:
\n ------- \n
{document_content}
\n ------- \n
Please answer with yes or no:"""
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individulal entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Please formulate your answer as a list of subquestions:
Answer:
"""
REWRITE_PROMPT_SINGLE = """ \n
Please convert an initial user question into a more appropriate search query for retrievel from a
document store. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Formulate the query: """
MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
Use three sentences maximum and keep the answer concise.
Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!
\nQuestion: {question}
\nContext: {combined_context} \n
Answer:"""
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
you have an idea of how the avaiolable data looks like.
Your role is to generate 4-6 new sub-questions that would help to answer the initial question,
considering:
1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context
The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.
Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcoming of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please also provide a search term that can be used to retrieve relevant
documents from a document store.
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n
Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n
Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?
Generate the list of json dictionaries with the following format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
DECOMPOSE_PROMPT = """ \n
For an initial user question, please generate at 5-10 individual sub-questions whose answers would help
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question.
In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
document store, expressed as entities, relationships and terms. You can also think about the types
mentioned in brackets
Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
and or resolve ambiguities
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question using as a search phrase for the document store>}}, ...]
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n
Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Don't be too specific unless the original question is specific.
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
...]}} """
#### Consolidations
COMBINED_CONTEXT = """-------
Below you will find useful information to answer the original question. First, you see a number of
sub-questions with their answers. This information should be considered to be more focussed and
somewhat more specific to the original question as it tries to contextualized facts.
After that will see the documents that were considered to be relevant to answer the original question.
Here are the sub-questions and their answers:
\n\n {deep_answer_context} \n\n
\n\n Here are the documents that were considered to be relevant to answer the original question:
\n\n {formated_docs} \n\n
----------------
"""
SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
Below you will find a question that we ultimately want to answer (the original question) and a list of
motivations in arbitrary order for generated sub-questions that are supposed to help us answering the
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
motivation and 2 is less relevant.)
Please rank the motivations in order of relevance for answering the original question. Also, try to
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
Ultimately, create a list with the motivation numbers where the number of the most relevant
motivations comes first.
Here is the original question:
\n\n {original_question} \n\n
\n\n Here is the list of sub-question motivations:
\n\n {sub_question_explanations} \n\n
----------------
Please think step by step and then generate the ranked list of motivations.
Please format your answer as a json object in the following format:
{{"reasonning": <explain your reasoning for the ranking>,
"ranked_motivations": <ranked list of motivation numbers>}}
"""

View File

@@ -0,0 +1,91 @@
import ast
import json
import re
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any
from danswer.context.search.models import InferenceSection
def normalize_whitespace(text: str) -> str:
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
import re
return re.sub(r"\s+", " ", text.strip())
# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
return "\n\n".join(doc.combined_content for doc in docs)
def clean_and_parse_list_string(json_string: str) -> list[dict]:
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Parse the cleaned string into a Python dictionary
return ast.literal_eval(cleaned_string)
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
# Remove markdown code block markers and any newline prefixes
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
cleaned_string = " ".join(cleaned_string.split())
# Parse the cleaned string into a Python dictionary
return json.loads(cleaned_string)
def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
entities = entity_term_extraction_dict["entities"]
terms = entity_term_extraction_dict["terms"]
relationships = entity_term_extraction_dict["relationships"]
entity_strs = ["\nEntities:\n"]
for entity in entities:
entity_str = f"{entity['entity_name']} ({entity['entity_type']})"
entity_strs.append(entity_str)
entity_str = "\n - ".join(entity_strs)
relationship_strs = ["\n\nRelationships:\n"]
for relationship in relationships:
relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
relationship_strs.append(relationship_str)
relationship_str = "\n - ".join(relationship_strs)
term_strs = ["\n\nTerms:\n"]
for term in terms:
term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
term_strs.append(term_str)
term_str = "\n - ".join(term_strs)
return "\n".join(entity_strs + relationship_strs + term_strs)
def _format_time_delta(time: timedelta) -> str:
seconds_from_start = f"{((time).seconds):03d}"
microseconds_from_start = f"{((time).microseconds):06d}"
return f"{seconds_from_start}.{microseconds_from_start}"
def generate_log_message(
message: str,
node_start_time: datetime,
graph_start_time: datetime | None = None,
) -> str:
current_time = datetime.now()
if graph_start_time is not None:
graph_time_str = _format_time_delta(current_time - graph_start_time)
else:
graph_time_str = "N/A"
node_time_str = _format_time_delta(current_time - node_start_time)
return f"{graph_time_str} ({node_time_str} s): {message}"

View File

@@ -25,6 +25,9 @@ class ToolCallSummary(BaseModel__v1):
tool_call_request: AIMessage
tool_call_result: ToolMessage
class Config:
arbitrary_types_allowed = True
def tool_call_tokens(
tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer

View File

@@ -26,9 +26,14 @@ huggingface-hub==0.20.1
jira==3.5.1
jsonref==1.1.0
trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
langchain==0.3.7
langchain-core==0.3.20
langchain-openai==0.2.9
langchain-text-splitters==0.3.2
langchainhub==0.1.21
langgraph==0.2.53
langgraph-checkpoint==2.0.5
langgraph-sdk==0.1.36
litellm==1.53.1
lxml==5.3.0
lxml_html_clean==0.2.2