Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 08:15:48 +00:00)

Compare commits: initial-im ... initial_im (2 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 0f83ae16b8 | |
| | ba54097069 | |
@@ -9,7 +9,7 @@ from danswer.agent_search.shared_graph_utils.utils import format_docs
 
 def answer_generation(state: AnswerQueryState) -> QAGenerationOutput:
     query = state["query_to_answer"]
-    docs = state["reordered_documents"]
+    docs = state["reranked_documents"]
 
     print(f"Number of verified retrieval docs: {len(docs)}")
 
@@ -10,7 +10,7 @@ def format_answer(state: AnswerQueryState) -> AnswerQueryOutput:
                 query=state["query_to_answer"],
                 quality=state["answer_quality"],
                 answer=state["answer"],
-                documents=state["reordered_documents"],
+                documents=state["reranked_documents"],
             )
         ],
     )
@@ -24,7 +24,7 @@ class QAGenerationOutput(TypedDict, total=False):
 
 
 class ExpandedRetrievalOutput(TypedDict):
-    reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
 
 
 class AnswerQueryState(
@@ -35,6 +35,9 @@ class AnswerQueryState(
     total=True,
 ):
     query_to_answer: str
+    retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+    verified_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
 
 
 class AnswerQueryInput(PrimaryState, total=True):
@@ -5,23 +5,36 @@ from langchain_core.messages import merge_message_runs
 from langgraph.types import Send
 
 from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput
+from danswer.agent_search.expanded_retrieval.states import DocRetrievalOutput
+from danswer.agent_search.expanded_retrieval.states import DocVerificationInput
 from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
-from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
+from danswer.agent_search.shared_graph_utils.prompts import (
+    REWRITE_PROMPT_MULTI_ORIGINAL,
+)
 from danswer.llm.interfaces import LLM
 
 
 def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]:
-    print(f"parallel_retrieval_edge state: {state.keys()}")
+    # print(f"parallel_retrieval_edge state: {state.keys()}")
+    print("parallel_retrieval_edge state")
 
     # This should be better...
     question = state.get("query_to_answer") or state["search_request"].query
     llm: LLM = state["fast_llm"]
 
+    """
     msg = [
         HumanMessage(
             content=REWRITE_PROMPT_MULTI.format(question=question),
         )
     ]
+    """
+    msg = [
+        HumanMessage(
+            content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
+        )
+    ]
 
     llm_response_list = list(
         llm.stream(
             prompt=msg,
@@ -29,9 +42,14 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]:
         )
     )
     llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
 
-    print(f"llm_response: {llm_response}")
+    # print(f"llm_response: {llm_response}")
 
-    rewritten_queries = llm_response.split("\n")
+    rewritten_queries = [
+        rewritten_query.strip() for rewritten_query in llm_response.split("--")
+    ]
+
+    # Add the original sub-question as one of the 'rewritten' queries
+    rewritten_queries = [question] + rewritten_queries
 
     print(f"rewritten_queries: {rewritten_queries}")
 
@@ -42,3 +60,24 @@ def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]:
         )
         for query in rewritten_queries
     ]
+
+
+def parallel_verification_edge(state: DocRetrievalOutput) -> list[Send | Hashable]:
+    # print(f"parallel_retrieval_edge state: {state.keys()}")
+    print("parallel_retrieval_edge state")
+
+    retrieved_docs = state["retrieved_documents"]
+
+    return [
+        Send(
+            "doc_verification",
+            DocVerificationInput(doc_to_verify=doc, **state),
+        )
+        for doc in retrieved_docs
+    ]
+
+
+# this is not correct - remove
+# def conditionally_rerank_edge(state: ExpandedRetrievalState) -> bool:
+#     print(f"conditionally_rerank_edge state: {state.keys()}")
+#     return bool(state["search_request"].rerank_settings)
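Side note on the pattern above: both parallel_retrieval_edge and the new parallel_verification_edge return a list of LangGraph Send objects from a conditional edge, which fans a single state out into one node invocation per item. A minimal, self-contained sketch of that fan-out (toy state, node, and field names are illustrative, not repo code):

```python
import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class State(TypedDict):
    queries: list[str]
    # reducer so writes from parallel branches are merged instead of overwriting
    results: Annotated[list[str], operator.add]


class ItemState(TypedDict):
    query: str


def fan_out(state: State) -> list[Send]:
    # One "retrieve" invocation per query, each with its own payload.
    return [Send("retrieve", {"query": q}) for q in state["queries"]]


def retrieve(state: ItemState) -> dict:
    return {"results": [f"docs for {state['query']}"]}


builder = StateGraph(State)
builder.add_node("retrieve", retrieve)
builder.add_conditional_edges(START, fan_out, ["retrieve"])
builder.add_edge("retrieve", END)
graph = builder.compile()

# graph.invoke({"queries": ["q1", "q2"], "results": []})
# returns a final state whose "results" holds one entry per query
```

The payload handed to Send does not have to match the top-level state schema; each branch's return value is merged back through the annotated reducer.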
@@ -3,11 +3,13 @@ from langgraph.graph import START
 from langgraph.graph import StateGraph
 
 from danswer.agent_search.expanded_retrieval.edges import parallel_retrieval_edge
+from danswer.agent_search.expanded_retrieval.edges import parallel_verification_edge
 from danswer.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking
 from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval
 from danswer.agent_search.expanded_retrieval.nodes.doc_verification import (
     doc_verification,
 )
+from danswer.agent_search.expanded_retrieval.nodes.dummy_node import dummy_node
 from danswer.agent_search.expanded_retrieval.nodes.verification_kickoff import (
     verification_kickoff,
 )
@@ -15,6 +17,8 @@ from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
 from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput
 from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
 
+# from danswer.agent_search.expanded_retrieval.edges import conditionally_rerank_edge
+
 
 def expanded_retrieval_graph_builder() -> StateGraph:
     graph = StateGraph(
@@ -42,6 +46,16 @@ def expanded_retrieval_graph_builder() -> StateGraph:
         action=doc_reranking,
     )
 
+    graph.add_node(
+        node="post_retrieval_dummy_node",
+        action=dummy_node,
+    )
+
+    graph.add_node(
+        node="dummy_node",
+        action=dummy_node,
+    )
+
     ### Add edges ###
 
     graph.add_conditional_edges(
@@ -49,16 +63,43 @@ def expanded_retrieval_graph_builder() -> StateGraph:
         path=parallel_retrieval_edge,
         path_map=["doc_retrieval"],
     )
 
     graph.add_edge(
         start_key="doc_retrieval",
         end_key="verification_kickoff",
     )
 
+    graph.add_conditional_edges(
+        source="verification_kickoff",
+        path=parallel_verification_edge,
+        path_map=["doc_verification"],
+    )
+
+    # graph.add_edge(
+    #     start_key="doc_verification",
+    #     end_key="post_retrieval_dummy_node",
+    # )
+
     graph.add_edge(
         start_key="doc_verification",
         end_key="doc_reranking",
     )
 
+    graph.add_edge(
+        start_key="doc_reranking",
+        end_key="dummy_node",
+    )
+
     # graph.add_conditional_edges(
     #     source="doc_verification",
     #     path=conditionally_rerank_edge,
     #     path_map={
     #         True: "doc_reranking",
     #         False: END,
     #     },
     # )
+    graph.add_edge(
+        start_key="dummy_node",
+        end_key=END,
+    )
 
@@ -1,9 +1,11 @@
+import datetime
+
 from danswer.agent_search.expanded_retrieval.states import DocRerankingOutput
 from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
 
 
 def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingOutput:
-    print(f"doc_reranking state: {state.keys()}")
+    print(f"doc_reranking state: {datetime.datetime.now()}")
 
     verified_documents = state["verified_documents"]
     reranked_documents = verified_documents
 
@@ -1,9 +1,10 @@
+import datetime
 from danswer.agent_search.expanded_retrieval.states import DocRetrievalOutput
 from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
 from danswer.context.search.models import InferenceSection
 from danswer.context.search.models import SearchRequest
 from danswer.context.search.pipeline import SearchPipeline
 from danswer.db.engine import get_session_context_manager
 
 
 class RetrieveInput(ExpandedRetrievalState):
@@ -21,27 +22,54 @@ def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput:
     Returns:
         state (dict): New key added to state, documents, that contains retrieved documents
     """
-    print(f"doc_retrieval state: {state.keys()}")
+    # print(f"doc_retrieval state: {state.keys()}")
 
     state["query_to_retrieve"]
+    if "query_to_answer" in state.keys():
+        query_question = state["query_to_answer"]
+    else:
+        query_question = state["search_request"].query
+
     query_to_retrieve = state["query_to_retrieve"]
 
+    print(f"\ndoc_retrieval state: {datetime.datetime.now()}")
+    print(f" -- search_request: {query_question[:100]}")
+    # print(f" -- query_to_retrieve: {query_to_retrieve[:100]}")
+
     documents: list[InferenceSection] = []
     llm = state["primary_llm"]
     fast_llm = state["fast_llm"]
-    documents = SearchPipeline(
-        search_request=SearchRequest(
-            query=query_to_retrieve,
-        ),
-        user=None,
-        llm=llm,
-        fast_llm=fast_llm,
-        db_session=state["db_session"],
-    ).reranked_sections
+    # db_session = state["db_session"]
+    query_to_retrieve = state["search_request"].query
+    with get_session_context_manager() as db_session1:
+        documents = SearchPipeline(
+            search_request=SearchRequest(
+                query=query_to_retrieve,
+            ),
+            user=None,
+            llm=llm,
+            fast_llm=fast_llm,
+            db_session=db_session1,
+        ).reranked_sections
+
+    top_1_score = documents[0].center_chunk.score
+    top_5_score = sum([doc.center_chunk.score for doc in documents[:5]]) / 5
+    top_10_score = sum([doc.center_chunk.score for doc in documents[:10]]) / 10
+
+    fit_score = 1 / 3 * (top_1_score + top_5_score + top_10_score)
+
+    # temp - limit the number of documents to 5
+    documents = documents[:5]
+
+    """
+    chunk_ids = {
+        "query": query_to_retrieve,
+        "chunk_ids": [doc.center_chunk.chunk_id for doc in documents],
+    }
+    """
+
+    print(f"sub_query: {query_to_retrieve[:50]}")
     print(f"retrieved documents: {len(documents)}")
+    print(f"fit score: {fit_score}")
     print()
     return DocRetrievalOutput(
         retrieved_documents=documents,
     )
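The fit_score added above is a plain average of three rank-truncated score averages over the reranked sections' center_chunk.score values. A small standalone sketch of the same arithmetic (not repo code; like the snippet above, it assumes at least 10 scores are available):

```python
def fit_score(scores: list[float]) -> float:
    """Mirrors: 1/3 * (top-1 score + mean of top 5 + mean of top 10)."""
    top_1 = scores[0]
    top_5 = sum(scores[:5]) / 5
    top_10 = sum(scores[:10]) / 10
    return (top_1 + top_5 + top_10) / 3


scores = [0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.05]
print(round(fit_score(scores), 3))  # 0.685 = (0.9 + 0.7 + 0.455) / 3
```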
@@ -1,18 +1,15 @@
+import datetime
+
 from langchain_core.messages import HumanMessage
 from langchain_core.messages import merge_message_runs
 
+from danswer.agent_search.expanded_retrieval.states import DocRetrievalOutput
 from danswer.agent_search.expanded_retrieval.states import DocVerificationOutput
-from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
 from danswer.agent_search.shared_graph_utils.models import BinaryDecision
 from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
-from danswer.context.search.models import InferenceSection
-
-
-class DocVerificationInput(ExpandedRetrievalState, total=True):
-    doc_to_verify: InferenceSection
 
 
-def doc_verification(state: DocVerificationInput) -> DocVerificationOutput:
+def doc_verification(state: DocRetrievalOutput) -> DocVerificationOutput:
     """
     Check whether the document is relevant for the original user question
 
@@ -23,16 +20,20 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationOutput:
         dict: ict: The updated state with the final decision
     """
 
-    print(f"doc_verification state: {state.keys()}")
+    # print(f"--- doc_verification state ---")
 
+    if "query_to_answer" in state.keys():
+        query_to_answer = state["query_to_answer"]
+    else:
+        query_to_answer = state["search_request"].query
+
-    original_query = state["search_request"].query
     doc_to_verify = state["doc_to_verify"]
     document_content = doc_to_verify.combined_content
 
     msg = [
         HumanMessage(
             content=VERIFIER_PROMPT.format(
-                question=original_query, document_content=document_content
+                question=query_to_answer, document_content=document_content
             )
         )
     ]
@@ -49,12 +50,14 @@ def doc_verification(state: DocVerificationInput) -> DocVerificationOutput:
     decision_dict = {"decision": response_string.lower()}
     formatted_response = BinaryDecision.model_validate(decision_dict)
 
-    print(f"Verdict: {formatted_response.decision}")
-
     verified_documents = []
     if formatted_response.decision == "yes":
        verified_documents.append(doc_to_verify)
 
+    print(
+        f"Verdict & Completion: {formatted_response.decision} -- {datetime.datetime.now()}"
+    )
+
     return DocVerificationOutput(
         verified_documents=verified_documents,
     )
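doc_verification turns the verifier LLM's raw reply into a structured yes/no by lower-casing it and validating it against the BinaryDecision Pydantic model. A self-contained sketch of that validation step (BinaryDecision is redefined here for illustration and may differ from the repo's model):

```python
from typing import Literal

from pydantic import BaseModel


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


def parse_verdict(response_string: str) -> BinaryDecision:
    # mirrors: decision_dict = {"decision": response_string.lower()}
    decision_dict = {"decision": response_string.strip().lower()}
    return BinaryDecision.model_validate(decision_dict)


print(parse_verdict(" Yes ").decision)  # -> yes
# Any reply other than yes/no raises a pydantic ValidationError.
```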
@@ -0,0 +1,9 @@
+def dummy_node(state):
+    """
+    This node is a dummy node that does not change the state but allows to inspect the state.
+    """
+    print(f"doc_reranking state: {state.keys()}")
+
+    state["verified_documents"]
+
+    return {}
@@ -1,9 +1,10 @@
+import datetime
 from typing import Literal
 
 from langgraph.types import Command
 from langgraph.types import Send
 
-from danswer.agent_search.expanded_retrieval.nodes.doc_verification import (
+from danswer.agent_search.expanded_retrieval.states import (
     DocVerificationInput,
 )
 from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
@@ -12,7 +13,7 @@ from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
 def verification_kickoff(
     state: ExpandedRetrievalState,
 ) -> Command[Literal["doc_verification"]]:
-    print(f"verification_kickoff state: {state.keys()}")
+    print(f"verification_kickoff state: {datetime.datetime.now()}")
 
     documents = state["retrieved_documents"]
     return Command(
@@ -8,6 +8,12 @@ from danswer.context.search.models import InferenceSection
 
 class DocRetrievalOutput(TypedDict, total=False):
     retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+    query_to_answer: str
+
+
+class DocVerificationInput(TypedDict, total=True):
+    query_to_answer: str
+    doc_to_verify: InferenceSection
 
 
 class DocVerificationOutput(TypedDict, total=False):
@@ -33,4 +39,4 @@ class ExpandedRetrievalInput(PrimaryState, total=True):
 
 
 class ExpandedRetrievalOutput(TypedDict):
-    reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]
+    reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
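The document-list fields in these states are Annotated with dedup_inference_sections so that LangGraph merges and de-duplicates lists written by parallel branches instead of rejecting conflicting updates. An illustrative sketch of that reducer pattern with stand-in types (Doc and dedup_docs are not repo code):

```python
from typing import Annotated, TypedDict


class Doc(TypedDict):
    id: str
    text: str


def dedup_docs(existing: list[Doc], new: list[Doc]) -> list[Doc]:
    # LangGraph calls the reducer with the current value and the incoming
    # update; keep the first occurrence of each id, preserving order.
    merged: dict[str, Doc] = {}
    for doc in list(existing) + list(new):
        merged.setdefault(doc["id"], doc)
    return list(merged.values())


class RetrievalState(TypedDict, total=False):
    retrieved_documents: Annotated[list[Doc], dedup_docs]
    verified_documents: Annotated[list[Doc], dedup_docs]
```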
@@ -1,3 +1,5 @@
+import datetime
+
 from langgraph.graph import END
 from langgraph.graph import START
 from langgraph.graph import StateGraph
@@ -8,6 +10,7 @@ from danswer.agent_search.expanded_retrieval.graph_builder import (
 )
 from danswer.agent_search.main.edges import parallelize_decompozed_answer_queries
 from danswer.agent_search.main.nodes.base_decomp import main_decomp_base
+from danswer.agent_search.main.nodes.dummy_node import dummy_node
 from danswer.agent_search.main.nodes.generate_initial_answer import (
     generate_initial_answer,
 )
@@ -23,6 +26,16 @@ def main_graph_builder() -> StateGraph:
 
     ### Add nodes ###
 
+    graph.add_node(
+        node="dummy_node_start",
+        action=dummy_node,
+    )
+
+    graph.add_node(
+        node="dummy_node_right",
+        action=dummy_node,
+    )
+
     graph.add_node(
         node="base_decomp",
         action=main_decomp_base,
@@ -43,13 +56,27 @@ def main_graph_builder() -> StateGraph:
     )
 
     ### Add edges ###
-    graph.add_edge(
-        start_key=START,
-        end_key="expanded_retrieval",
-    )
 
+    graph.add_edge(
+        start_key=START,
+        end_key="dummy_node_start",
+    )
+
+    graph.add_edge(
+        start_key="dummy_node_start",
+        end_key="dummy_node_right",
+    )
+    graph.add_edge(
+        start_key="dummy_node_right",
+        end_key="expanded_retrieval",
+    )
+    # graph.add_edge(
+    #     start_key="expanded_retrieval",
+    #     end_key="generate_initial_answer",
+    # )
 
+    graph.add_edge(
+        start_key="dummy_node_start",
+        end_key="base_decomp",
+    )
     graph.add_conditional_edges(
@@ -78,7 +105,7 @@ if __name__ == "__main__":
     compiled_graph = graph.compile()
     primary_llm, fast_llm = get_default_llms()
     search_request = SearchRequest(
-        query="If i am familiar with the function that I need, how can I type it into a cell?",
+        query="Who made Excel and what other products did they make?",
     )
     with get_session_context_manager() as db_session:
         inputs = MainInput(
@@ -87,12 +114,12 @@ if __name__ == "__main__":
             fast_llm=fast_llm,
             db_session=db_session,
         )
-        for thing in compiled_graph.stream(
+        print(f"START: {datetime.datetime.now()}")
+
+        output = compiled_graph.invoke(
             input=inputs,
             # stream_mode="debug",
             # debug=True,
-            subgraphs=True,
-        ):
-            # print(thing)
-            print()
-            print()
+            # subgraphs=True,
+        )
+        print(output)
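The __main__ block above switches from iterating compiled_graph.stream(..., subgraphs=True) to a single compiled_graph.invoke(...) call that returns only the final state. A toy illustration of the difference (the graph and node names here are made up for the example, not taken from the repo):

```python
from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    value: int


def bump(state: State) -> dict:
    return {"value": state["value"] + 1}


builder = StateGraph(State)
builder.add_node("bump", bump)
builder.add_edge(START, "bump")
builder.add_edge("bump", END)
compiled_graph = builder.compile()

# Streaming: yields intermediate updates as each node finishes
# (subgraphs=True also surfaces events from nested graphs).
for chunk in compiled_graph.stream({"value": 0}, subgraphs=True):
    print(chunk)

# Invoke: runs the graph to completion and returns only the final state.
output = compiled_graph.invoke({"value": 0})
print(output)  # {'value': 1}
```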
@@ -1,3 +1,5 @@
+import datetime
+
 from langchain_core.messages import HumanMessage
 
 from danswer.agent_search.main.states import BaseDecompOutput
@@ -7,6 +9,7 @@ from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_s
 
 
 def main_decomp_base(state: MainState) -> BaseDecompOutput:
+    print(f"main_decomp_base state: {datetime.datetime.now()}")
     question = state["search_request"].query
 
     msg = [
@@ -26,6 +29,7 @@ def main_decomp_base(state: MainState) -> BaseDecompOutput:
         sub_question["sub_question"].strip() for sub_question in list_of_subquestions
     ]
 
+    print(f"Decomp Questions: {decomp_list}")
     return BaseDecompOutput(
         initial_decomp_queries=decomp_list,
     )
backend/danswer/agent_search/main/nodes/dummy_node.py (new file, 10 lines)
@@ -0,0 +1,10 @@
+import datetime
+
+
+def dummy_node(state):
+    """
+    This node is a dummy node that does not change the state but allows to inspect the state.
+    """
+    print(f"DUMMY NODE: {datetime.datetime.now()}")
+
+    return {}
@@ -21,7 +21,7 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput:
     """
     for decomp_answer_result in decomp_answer_results:
         if (
-            decomp_answer_result.quality.lower() == "yes"
+            decomp_answer_result.quality == "yes"
            and len(decomp_answer_result.answer) > 0
            and decomp_answer_result.answer != "I don't know"
         ):
@@ -47,7 +47,5 @@ def generate_initial_answer(state: MainState) -> InitialAnswerOutput:
     # Grader
     model = state["fast_llm"]
     response = model.invoke(msg)
-    answer = response.pretty_repr()
 
-    print(answer)
-    return InitialAnswerOutput(initial_answer=answer)
+    return InitialAnswerOutput(initial_answer=response.pretty_repr())