Compare commits

..

2 Commits

Author SHA1 Message Date
joachim-danswer
0f83ae16b8 tmp 2024-12-15 16:52:38 -08:00
joachim-danswer
ba54097069 initial adjustments 2024-12-14 09:44:44 -08:00
15 changed files with 224 additions and 53 deletions

View File

@@ -9,7 +9,7 @@ from danswer.agent_search.shared_graph_utils.utils import format_docs
def answer_generation(state: AnswerQueryState) -> QAGenerationOutput:
query = state["query_to_answer"]
docs = state["reordered_documents"]
docs = state["reranked_documents"]
print(f"Number of verified retrieval docs: {len(docs)}")

View File

@@ -10,7 +10,7 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput:
query=state["query_to_answer"],
quality=state["answer_quality"],
answer=state["answer"],
documents=state["reordered_documents"],
documents=state["reranked_documents"],
)
],
)

View File

@@ -24,7 +24,7 @@ class QAGenerationOutput(TypedDict, total=False):
class ExpandedRetrievalOutput(TypedDict):
reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
class AnswerQueryState(
@@ -35,6 +35,9 @@ class AnswerQueryState(
total=True,
):
query_to_answer: str
retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections]
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
class AnswerQueryInput(PrimaryState, total=True):

View File

@@ -5,23 +5,36 @@ from langchain_core.messages import merge_message_runs
from langgraph.types import Send
from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput
from danswer.agent_search.expanded_retrieval.states import DocRetrievalOutput
from danswer.agent_search.expanded_retrieval.states import DocVerificationInput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.prompts import (
REWRITE_PROMPT_MULTI_ORIGINAL,
)
from danswer.llm.interfaces import LLM
def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]:
print(f"parallel_retrieval_edge state: {state.keys()}")
# print(f"parallel_retrieval_edge state: {state.keys()}")
print("parallel_retrieval_edge state")
# This should be better...
question = state.get("query_to_answer") or state["search_request"].query
llm: LLM = state["fast_llm"]
"""
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI.format(question=question),
)
]
"""
msg = [
HumanMessage(
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
)
]
llm_response_list = list(
llm.stream(
prompt=msg,
@@ -29,9 +42,14 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab
)
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
print(f"llm_response: {llm_response}")
# print(f"llm_response: {llm_response}")
rewritten_queries = llm_response.split("\n")
rewritten_queries = [
rewritten_query.strip() for rewritten_query in llm_response.split("--")
]
# Add the original sub-question as one of the 'rewritten' queries
rewritten_queries = [question] + rewritten_queries
print(f"rewritten_queries: {rewritten_queries}")
@@ -42,3 +60,24 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashab
)
for query in rewritten_queries
]
def parallel_verification_edge(state: DocRetrievalOutput) -> list[Send | Hashable]:
# print(f"parallel_retrieval_edge state: {state.keys()}")
print("parallel_retrieval_edge state")
retrieved_docs = state["retrieved_documents"]
return [
Send(
"doc_verification",
DocVerificationInput(doc_to_verify=doc, **state),
)
for doc in retrieved_docs
]
# this is not correct - remove
# def conditionally_rerank_edge(state: ExpandedRetrievalState) -> bool:
# print(f"conditionally_rerank_edge state: {state.keys()}")
# return bool(state["search_request"].rerank_settings)

View File

@@ -3,11 +3,13 @@ from langgraph.graph import START
from langgraph.graph import StateGraph
from danswer.agent_search.expanded_retrieval.edges import parallel_retrieval_edge
from danswer.agent_search.expanded_retrieval.edges import parallel_verification_edge
from danswer.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking
from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval
from danswer.agent_search.expanded_retrieval.nodes.doc_verification import (
doc_verification,
)
from danswer.agent_search.expanded_retrieval.nodes.dummy_node import dummy_node
from danswer.agent_search.expanded_retrieval.nodes.verification_kickoff import (
verification_kickoff,
)
@@ -15,6 +17,8 @@ from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInpu
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
# from danswer.agent_search.expanded_retrieval.edges import conditionally_rerank_edge
def expanded_retrieval_graph_builder() -> StateGraph:
graph = StateGraph(
@@ -42,6 +46,16 @@ def expanded_retrieval_graph_builder() -> StateGraph:
action=doc_reranking,
)
graph.add_node(
node="post_retrieval_dummy_node",
action=dummy_node,
)
graph.add_node(
node="dummy_node",
action=dummy_node,
)
### Add edges ###
graph.add_conditional_edges(
@@ -49,16 +63,43 @@ def expanded_retrieval_graph_builder() -> StateGraph:
path=parallel_retrieval_edge,
path_map=["doc_retrieval"],
)
graph.add_edge(
start_key="doc_retrieval",
end_key="verification_kickoff",
)
graph.add_conditional_edges(
source="verification_kickoff",
path=parallel_verification_edge,
path_map=["doc_verification"],
)
# graph.add_edge(
# start_key="doc_verification",
# end_key="post_retrieval_dummy_node",
# )
graph.add_edge(
start_key="doc_verification",
end_key="doc_reranking",
)
graph.add_edge(
start_key="doc_reranking",
end_key="dummy_node",
)
# graph.add_conditional_edges(
# source="doc_verification",
# path=conditionally_rerank_edge,
# path_map={
# True: "doc_reranking",
# False: END,
# },
# )
graph.add_edge(
start_key="dummy_node",
end_key=END,
)

View File

@@ -1,9 +1,11 @@
import datetime
from danswer.agent_search.expanded_retrieval.states import DocRerankingOutput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingOutput:
print(f"doc_reranking state: {state.keys()}")
print(f"doc_reranking state: {datetime.datetime.now()}")
verified_documents = state["verified_documents"]
reranked_documents = verified_documents

View File

@@ -1,9 +1,10 @@
import datetime
from danswer.agent_search.expanded_retrieval.states import DocRetrievalOutput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
class RetrieveInput(ExpandedRetrievalState):
@@ -21,27 +22,54 @@ def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput:
Returns:
state (dict): New key added to state, documents, that contains retrieved documents
"""
print(f"doc_retrieval state: {state.keys()}")
# print(f"doc_retrieval state: {state.keys()}")
state["query_to_retrieve"]
if "query_to_answer" in state.keys():
query_question = state["query_to_answer"]
else:
query_question = state["search_request"].query
query_to_retrieve = state["query_to_retrieve"]
print(f"\ndoc_retrieval state: {datetime.datetime.now()}")
print(f" -- search_request: {query_question[:100]}")
# print(f" -- query_to_retrieve: {query_to_retrieve[:100]}")
documents: list[InferenceSection] = []
llm = state["primary_llm"]
fast_llm = state["fast_llm"]
# db_session = state["db_session"]
query_to_retrieve = state["search_request"].query
with get_session_context_manager() as db_session1:
documents = SearchPipeline(
search_request=SearchRequest(
query=query_to_retrieve,
),
user=None,
llm=llm,
fast_llm=fast_llm,
db_session=db_session1,
).reranked_sections
documents = SearchPipeline(
search_request=SearchRequest(
query=query_to_retrieve,
),
user=None,
llm=llm,
fast_llm=fast_llm,
db_session=state["db_session"],
).reranked_sections
top_1_score = documents[0].center_chunk.score
top_5_score = sum([doc.center_chunk.score for doc in documents[:5]]) / 5
top_10_score = sum([doc.center_chunk.score for doc in documents[:10]]) / 10
fit_score = 1 / 3 * (top_1_score + top_5_score + top_10_score)
# temp - limit the number of documents to 5
documents = documents[:5]
"""
chunk_ids = {
"query": query_to_retrieve,
"chunk_ids": [doc.center_chunk.chunk_id for doc in documents],
}
"""
print(f"sub_query: {query_to_retrieve[:50]}")
print(f"retrieved documents: {len(documents)}")
print(f"fit score: {fit_score}")
print()
return DocRetrievalOutput(
retrieved_documents=documents,
)

View File

@@ -1,18 +1,15 @@
import datetime
from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs
from danswer.agent_search.expanded_retrieval.states import DocRetrievalOutput
from danswer.agent_search.expanded_retrieval.states import DocVerificationOutput
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.context.search.models import InferenceSection
class DocVerificationInput(ExpandedRetrievalState, total=True):
doc_to_verify: InferenceSection
def doc_verification(state: DocVerificationInput) -> DocVerificationOutput:
def doc_verification(state: DocRetrievalOutput) -> DocVerificationOutput:
"""
Check whether the document is relevant for the original user question
@@ -23,16 +20,20 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationOutput:
dict: The updated state with the final decision
"""
print(f"doc_verification state: {state.keys()}")
# print(f"--- doc_verification state ---")
if "query_to_answer" in state.keys():
query_to_answer = state["query_to_answer"]
else:
query_to_answer = state["search_request"].query
original_query = state["search_request"].query
doc_to_verify = state["doc_to_verify"]
document_content = doc_to_verify.combined_content
msg = [
HumanMessage(
content=VERIFIER_PROMPT.format(
question=original_query, document_content=document_content
question=query_to_answer, document_content=document_content
)
)
]
@@ -49,12 +50,14 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationOutput:
decision_dict = {"decision": response_string.lower()}
formatted_response = BinaryDecision.model_validate(decision_dict)
print(f"Verdict: {formatted_response.decision}")
verified_documents = []
if formatted_response.decision == "yes":
verified_documents.append(doc_to_verify)
print(
f"Verdict & Completion: {formatted_response.decision} -- {datetime.datetime.now()}"
)
return DocVerificationOutput(
verified_documents=verified_documents,
)

View File

@@ -0,0 +1,9 @@
def dummy_node(state):
"""
This node is a dummy node that does not change the state but allows inspecting the state.
"""
print(f"doc_reranking state: {state.keys()}")
state["verified_documents"]
return {}

View File

@@ -1,9 +1,10 @@
import datetime
from typing import Literal
from langgraph.types import Command
from langgraph.types import Send
from danswer.agent_search.expanded_retrieval.nodes.doc_verification import (
from danswer.agent_search.expanded_retrieval.states import (
DocVerificationInput,
)
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
@@ -12,7 +13,7 @@ from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalStat
def verification_kickoff(
state: ExpandedRetrievalState,
) -> Command[Literal["doc_verification"]]:
print(f"verification_kickoff state: {state.keys()}")
print(f"verification_kickoff state: {datetime.datetime.now()}")
documents = state["retrieved_documents"]
return Command(

View File

@@ -8,6 +8,12 @@ from danswer.context.search.models import InferenceSection
class DocRetrievalOutput(TypedDict, total=False):
retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]
query_to_answer: str
class DocVerificationInput(TypedDict, total=True):
query_to_answer: str
doc_to_verify: InferenceSection
class DocVerificationOutput(TypedDict, total=False):
@@ -33,4 +39,4 @@ class ExpandedRetrievalInput(PrimaryState, total=True):
class ExpandedRetrievalOutput(TypedDict):
reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]

View File

@@ -1,3 +1,5 @@
import datetime
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph
@@ -8,6 +10,7 @@ from danswer.agent_search.expanded_retrieval.graph_builder import (
)
from danswer.agent_search.main.edges import parallelize_decompozed_answer_queries
from danswer.agent_search.main.nodes.base_decomp import main_decomp_base
from danswer.agent_search.main.nodes.dummy_node import dummy_node
from danswer.agent_search.main.nodes.generate_initial_answer import (
generate_initial_answer,
)
@@ -23,6 +26,16 @@ def main_graph_builder() -> StateGraph:
### Add nodes ###
graph.add_node(
node="dummy_node_start",
action=dummy_node,
)
graph.add_node(
node="dummy_node_right",
action=dummy_node,
)
graph.add_node(
node="base_decomp",
action=main_decomp_base,
@@ -43,13 +56,27 @@ def main_graph_builder() -> StateGraph:
)
### Add edges ###
graph.add_edge(
start_key=START,
end_key="expanded_retrieval",
)
graph.add_edge(
start_key=START,
end_key="dummy_node_start",
)
graph.add_edge(
start_key="dummy_node_start",
end_key="dummy_node_right",
)
graph.add_edge(
start_key="dummy_node_right",
end_key="expanded_retrieval",
)
# graph.add_edge(
# start_key="expanded_retrieval",
# end_key="generate_initial_answer",
# )
graph.add_edge(
start_key="dummy_node_start",
end_key="base_decomp",
)
graph.add_conditional_edges(
@@ -78,7 +105,7 @@ if __name__ == "__main__":
compiled_graph = graph.compile()
primary_llm, fast_llm = get_default_llms()
search_request = SearchRequest(
query="If i am familiar with the function that I need, how can I type it into a cell?",
query="Who made Excel and what other products did they make?",
)
with get_session_context_manager() as db_session:
inputs = MainInput(
@@ -87,12 +114,12 @@ if __name__ == "__main__":
fast_llm=fast_llm,
db_session=db_session,
)
for thing in compiled_graph.stream(
print(f"START: {datetime.datetime.now()}")
output = compiled_graph.invoke(
input=inputs,
# stream_mode="debug",
# debug=True,
subgraphs=True,
):
# print(thing)
print()
print()
# subgraphs=True,
)
print(output)

View File

@@ -1,3 +1,5 @@
import datetime
from langchain_core.messages import HumanMessage
from danswer.agent_search.main.states import BaseDecompOutput
@@ -7,6 +9,7 @@ from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_s
def main_decomp_base(state: MainState) -> BaseDecompOutput:
print(f"main_decomp_base state: {datetime.datetime.now()}")
question = state["search_request"].query
msg = [
@@ -26,6 +29,7 @@ def main_decomp_base(state: MainState) -> BaseDecompOutput:
sub_question["sub_question"].strip() for sub_question in list_of_subquestions
]
print(f"Decomp Questions: {decomp_list}")
return BaseDecompOutput(
initial_decomp_queries=decomp_list,
)

View File

@@ -0,0 +1,10 @@
import datetime
def dummy_node(state):
"""
This node is a dummy node that does not change the state but allows inspecting the state.
"""
print(f"DUMMY NODE: {datetime.datetime.now()}")
return {}

View File

@@ -21,7 +21,7 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput:
"""
for decomp_answer_result in decomp_answer_results:
if (
decomp_answer_result.quality.lower() == "yes"
decomp_answer_result.quality == "yes"
and len(decomp_answer_result.answer) > 0
and decomp_answer_result.answer != "I don't know"
):
@@ -47,7 +47,5 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput:
# Grader
model = state["fast_llm"]
response = model.invoke(msg)
answer = response.pretty_repr()
print(answer)
return InitialAnswerOutput(initial_answer=answer)
return InitialAnswerOutput(initial_answer=response.pretty_repr())