Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 15:55:45 +00:00)

Compare commits: `initial_im...more-conf-` (31 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | 4f96954bcf |  |
|  | 6219f311bb |  |
|  | b25c10a51a |  |
|  | 1bcfa28fda |  |
|  | 7a0d823c89 |  |
|  | db69e445d6 |  |
|  | 18e63889b7 |  |
|  | 738e60c8ed |  |
|  | 8aec873e66 |  |
|  | 7c57dde8ab |  |
|  | f30adab853 |  |
|  | 601687a522 |  |
|  | 350cf407c9 |  |
|  | 32ec4efc7a |  |
|  | 7c6981e052 |  |
|  | c50cd20156 |  |
|  | 14772dee71 |  |
|  | c81e704c95 |  |
|  | 3266ef6321 |  |
|  | c89b98b4f2 |  |
|  | e70e0ab859 |  |
|  | 69b6e9321e |  |
|  | 7e53af18b6 |  |
|  | b9eb1ca2ba |  |
|  | 91d44c83d2 |  |
|  | 4dbc6bb4d1 |  |
|  | 4b6a4c6bbf |  |
|  | fd1999454a |  |
|  | 0a35422d1d |  |
|  | 69b99056b2 |  |
|  | 2a55696545 |  |
@@ -1,42 +0,0 @@
```python
from collections.abc import Hashable
from typing import Union

from langgraph.types import Send

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.primary_graph.states import VerifierState


def sub_continue_to_verifier(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
    # Routes each de-duped retrieved doc to the verifier step - in parallel
    # Notice the 'Send()' API that takes care of the parallelization

    return [
        Send(
            "sub_verifier",
            VerifierState(
                document=doc,
                # question=state["original_question"],
                question=state["sub_question_str"],
                graph_start_time=state["graph_start_time"],
            ),
        )
        for doc in state["sub_question_deduped_retrieval_docs"]
    ]


def sub_continue_to_retrieval(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
    # Routes re-written queries to the (parallel) retrieval steps
    # Notice the 'Send()' API that takes care of the parallelization
    rewritten_queries = state["sub_question_search_queries"].rewritten_queries + [
        state["sub_question_str"]
    ]
    return [
        Send(
            "sub_custom_retrieve",
            RetrieverState(
                rewritten_query=query,
                graph_start_time=state["graph_start_time"],
            ),
        )
        for query in rewritten_queries
    ]
```
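For reference, the `Send()` fan-out pattern these edge functions rely on reduces to a minimal, self-contained sketch; the state and node names below are illustrative, not part of the Onyx codebase:

```python
import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph
from langgraph.types import Send


class OverallState(TypedDict):
    queries: list[str]
    results: Annotated[list[str], operator.add]  # parallel results are concatenated


class WorkerState(TypedDict):
    query: str


def fan_out(state: OverallState) -> list[Send]:
    # One Send() per query -> one parallel "retrieve" invocation per query
    return [Send("retrieve", WorkerState(query=q)) for q in state["queries"]]


def retrieve(state: WorkerState) -> dict:
    return {"results": [f"docs for: {state['query']}"]}


graph = StateGraph(OverallState)
graph.add_node("retrieve", retrieve)
graph.add_conditional_edges(START, fan_out, ["retrieve"])
graph.add_edge("retrieve", END)
print(graph.compile().invoke({"queries": ["q1", "q2"], "results": []}))
```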
@@ -1,132 +0,0 @@
```python
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from danswer.agent_search.core_qa_graph.edges import sub_continue_to_retrieval
from danswer.agent_search.core_qa_graph.edges import sub_continue_to_verifier
from danswer.agent_search.core_qa_graph.nodes.combine_retrieved_docs import (
    sub_combine_retrieved_docs,
)
from danswer.agent_search.core_qa_graph.nodes.custom_retrieve import (
    sub_custom_retrieve,
)
from danswer.agent_search.core_qa_graph.nodes.dummy import sub_dummy
from danswer.agent_search.core_qa_graph.nodes.final_format import (
    sub_final_format,
)
from danswer.agent_search.core_qa_graph.nodes.generate import sub_generate
from danswer.agent_search.core_qa_graph.nodes.qa_check import sub_qa_check
from danswer.agent_search.core_qa_graph.nodes.rewrite import sub_rewrite
from danswer.agent_search.core_qa_graph.nodes.verifier import sub_verifier
from danswer.agent_search.core_qa_graph.states import BaseQAOutputState
from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.core_qa_graph.states import CoreQAInputState


def build_core_qa_graph() -> StateGraph:
    sub_answers_initial = StateGraph(
        state_schema=BaseQAState,
        output=BaseQAOutputState,
    )

    ### Add nodes ###
    sub_answers_initial.add_node(node="sub_dummy", action=sub_dummy)
    sub_answers_initial.add_node(node="sub_rewrite", action=sub_rewrite)
    sub_answers_initial.add_node(
        node="sub_custom_retrieve",
        action=sub_custom_retrieve,
    )
    sub_answers_initial.add_node(
        node="sub_combine_retrieved_docs",
        action=sub_combine_retrieved_docs,
    )
    sub_answers_initial.add_node(
        node="sub_verifier",
        action=sub_verifier,
    )
    sub_answers_initial.add_node(
        node="sub_generate",
        action=sub_generate,
    )
    sub_answers_initial.add_node(
        node="sub_qa_check",
        action=sub_qa_check,
    )
    sub_answers_initial.add_node(
        node="sub_final_format",
        action=sub_final_format,
    )

    ### Add edges ###
    sub_answers_initial.add_edge(START, "sub_dummy")
    sub_answers_initial.add_edge("sub_dummy", "sub_rewrite")

    sub_answers_initial.add_conditional_edges(
        source="sub_rewrite",
        path=sub_continue_to_retrieval,
    )

    sub_answers_initial.add_edge(
        start_key="sub_custom_retrieve",
        end_key="sub_combine_retrieved_docs",
    )

    sub_answers_initial.add_conditional_edges(
        source="sub_combine_retrieved_docs",
        path=sub_continue_to_verifier,
        path_map=["sub_verifier"],
    )

    sub_answers_initial.add_edge(
        start_key="sub_verifier",
        end_key="sub_generate",
    )

    sub_answers_initial.add_edge(
        start_key="sub_generate",
        end_key="sub_qa_check",
    )

    sub_answers_initial.add_edge(
        start_key="sub_qa_check",
        end_key="sub_final_format",
    )

    sub_answers_initial.add_edge(
        start_key="sub_final_format",
        end_key=END,
    )
    # sub_answers_graph = sub_answers_initial.compile()
    return sub_answers_initial


if __name__ == "__main__":
    # q = "Whose music is kind of hard to easily enjoy?"
    # q = "What is voice leading?"
    # q = "What are the types of motions in music?"
    # q = "What are key elements of music theory?"
    # q = "How can I best understand music theory using voice leading?"
    q = "What makes good music?"
    # q = "types of motions in music"
    # q = "What is the relationship between music and physics?"
    # q = "Can you compare various grunge styles?"
    # q = "Why is quantum gravity so hard?"

    inputs = CoreQAInputState(
        original_question=q,
        sub_question_str=q,
    )
    sub_answers_graph = build_core_qa_graph()
    compiled_sub_answers = sub_answers_graph.compile()
    output = compiled_sub_answers.invoke(inputs)
    print("\nOUTPUT:")
    print(output.keys())
    for key, value in output.items():
        if key in [
            "sub_question_answer",
            "sub_question_str",
            "sub_qas",
            "initial_sub_qas",
        ]:
            print(f"{key}: {value}")
```
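As a usage note, the `__main__` block above runs the compiled subgraph in one shot; compiled LangGraph graphs also expose `stream()`, which yields one update per executed node and is handy for debugging a pipeline like this. A hedged sketch, reusing the names from the block above:

```python
# Hedged sketch: stream node-by-node updates instead of a single invoke().
for step in compiled_sub_answers.stream(inputs):
    print(step)  # one dict per executed node, keyed by node name
```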
@@ -1,36 +0,0 @@
```python
from datetime import datetime
from typing import Any

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection


def sub_combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]:
    """
    Dedupe the retrieved docs.
    """
    node_start_time = datetime.now()

    sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]

    print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
    dedupe_docs: list[InferenceSection] = []
    for base_retrieval_doc in sub_question_base_retrieval_docs:
        if not any(
            base_retrieval_doc.center_chunk.chunk_id == doc.center_chunk.chunk_id
            for doc in dedupe_docs
        ):
            dedupe_docs.append(base_retrieval_doc)

    print(f"Number of deduped docs: {len(dedupe_docs)}")

    return {
        "sub_question_deduped_retrieval_docs": dedupe_docs,
        "log_messages": generate_log_message(
            message="sub - combine_retrieved_docs (dedupe)",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
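The dedupe loop above is quadratic in the number of retrieved docs; an equivalent linear-time variant tracks seen chunk ids in a set. An illustrative sketch, not from the diff:

```python
# Linear-time dedupe by chunk id; behavior matches the loop above.
seen_chunk_ids = set()
dedupe_docs = []
for doc in sub_question_base_retrieval_docs:
    chunk_id = doc.center_chunk.chunk_id
    if chunk_id not in seen_chunk_ids:
        seen_chunk_ids.add(chunk_id)
        dedupe_docs.append(doc)
```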
@@ -1,66 +0,0 @@
```python
import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms


def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE SUB---")

    node_start_time = datetime.datetime.now()

    rewritten_query = state["rewritten_query"]

    # Retrieval
    # TODO: add the actual retrieval, probably from search_tool.run()
    llm, fast_llm = get_default_llms()
    with get_session_context_manager() as db_session:
        search_pipeline = SearchPipeline(
            search_request=SearchRequest(
                query=rewritten_query,
            ),
            user=None,
            llm=llm,
            fast_llm=fast_llm,
            db_session=db_session,
        )

    reranked_docs: list[InferenceSection] = search_pipeline.reranked_sections

    # initial metric to measure fit TODO: implement metric properly
    top_1_score = reranked_docs[0].center_chunk.score
    top_5_score = sum([doc.center_chunk.score for doc in reranked_docs[:5]]) / 5
    top_10_score = sum([doc.center_chunk.score for doc in reranked_docs[:10]]) / 10

    fit_score = 1 / 3 * (top_1_score + top_5_score + top_10_score)

    chunk_ids = {
        "query": rewritten_query,
        "chunk_ids": [doc.center_chunk.chunk_id for doc in reranked_docs],
    }

    return {
        "sub_question_base_retrieval_docs": reranked_docs,
        "sub_chunk_ids": [chunk_ids],
        "log_messages": generate_log_message(
            message=f"sub - custom_retrieve, fit_score: {fit_score}",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
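The fit-score metric above averages the reranker scores of the top 1, 5, and 10 sections and then averages the three means. Because it divides by fixed 5 and 10, fewer than ten results skews the averages downward, and an empty result list raises an `IndexError`. A guarded variant, as an illustrative sketch rather than part of the diff:

```python
def _mean_top_k(scores: list[float], k: int) -> float:
    # Average of the first k scores, tolerating short or empty lists.
    top = scores[:k]
    return sum(top) / len(top) if top else 0.0


scores = [doc.center_chunk.score for doc in reranked_docs]
fit_score = (
    _mean_top_k(scores, 1) + _mean_top_k(scores, 5) + _mean_top_k(scores, 10)
) / 3
```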
@@ -1,24 +0,0 @@
```python
import datetime
from typing import Any

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_dummy(state: BaseQAState) -> dict[str, Any]:
    """
    Dummy step
    """

    print("---Sub Dummy---")

    node_start_time = datetime.datetime.now()

    return {
        "graph_start_time": node_start_time,
        "log_messages": generate_log_message(
            message="sub - dummy",
            node_start_time=node_start_time,
            graph_start_time=node_start_time,
        ),
    }
```
@@ -1,22 +0,0 @@
```python
from typing import Any

from danswer.agent_search.core_qa_graph.states import BaseQAState


def sub_final_format(state: BaseQAState) -> dict[str, Any]:
    """
    Create the final output for the QA subgraph
    """

    print("---BASE FINAL FORMAT---")

    return {
        "sub_qas": [
            {
                "sub_question": state["sub_question_str"],
                "sub_answer": state["sub_question_answer"],
                "sub_answer_check": state["sub_question_answer_check"],
            }
        ],
        "log_messages": state["log_messages"],
    }
```
@@ -1,91 +0,0 @@
```python
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms


def sub_generate(state: BaseQAState) -> dict[str, Any]:
    """
    Generate answer

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the generated answer
    """
    print("---GENERATE---")

    # Create sub-query results
    verified_chunks = [
        chunk.center_chunk.chunk_id
        for chunk in state["sub_question_verified_retrieval_docs"]
    ]
    result_dict = {}

    chunk_id_dicts = state["sub_chunk_ids"]
    expanded_chunks = []
    original_chunks = []

    for chunk_id_dict in chunk_id_dicts:
        sub_question = chunk_id_dict["query"]
        verified_sq_chunks = [
            chunk_id
            for chunk_id in chunk_id_dict["chunk_ids"]
            if chunk_id in verified_chunks
        ]

        if sub_question != state["original_question"]:
            expanded_chunks += verified_sq_chunks
        else:
            result_dict["ORIGINAL"] = len(verified_sq_chunks)
            original_chunks += verified_sq_chunks
        result_dict[sub_question[:30]] = len(verified_sq_chunks)

    expansion_chunks = set(expanded_chunks)
    num_expansion_chunks = sum(
        1 for chunk_id in expansion_chunks if chunk_id in verified_chunks
    )
    num_original_relevant_chunks = len(original_chunks)
    num_missed_relevant_chunks = sum(
        1 for chunk_id in original_chunks if chunk_id not in expansion_chunks
    )
    num_gained_relevant_chunks = sum(
        1 for chunk_id in expansion_chunks if chunk_id not in original_chunks
    )
    result_dict["expansion_chunks"] = num_expansion_chunks

    print(result_dict)

    node_start_time = datetime.now()

    question = state["sub_question_str"]
    docs = state["sub_question_verified_retrieval_docs"]

    print(f"Number of verified retrieval docs: {len(docs)}")

    # Only take the top 10 docs.
    # TODO: Make this dynamic or use config param?
    top_10_docs = docs[-10:]

    msg = [
        HumanMessage(
            content=BASE_RAG_PROMPT.format(
                question=question, context=format_docs(top_10_docs)
            )
        )
    ]

    # Generate the answer with the fast LLM
    _, fast_llm = get_default_llms()
    response = list(
        fast_llm.stream(
            prompt=msg,
            # structured_response_format=None,
        )
    )

    answer_str = merge_message_runs(response, chunk_separator="")[0].content
    return {
        "sub_question_answer": answer_str,
        "log_messages": generate_log_message(
            message="base - generate",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
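The gained/missed bookkeeping above is plain set arithmetic over chunk ids; a worked micro-example with illustrative values:

```python
original = {"c1", "c2"}       # relevant chunks found by the original query
expanded = {"c2", "c3"}       # relevant chunks found by the expanded queries
missed = original - expanded  # {"c1"}: only the original query found it
gained = expanded - original  # {"c3"}: only the expansions found it
```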
@@ -1,51 +0,0 @@
```python
import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms


def sub_qa_check(state: BaseQAState) -> dict[str, Any]:
    """
    Check if the sub-question answer is satisfactory.

    Args:
        state: The current SubQAState containing the sub-question and its answer

    Returns:
        dict containing the check result and log message
    """
    node_start_time = datetime.datetime.now()

    msg = [
        HumanMessage(
            content=BASE_CHECK_PROMPT.format(
                question=state["sub_question_str"],
                base_answer=state["sub_question_answer"],
            )
        )
    ]

    _, fast_llm = get_default_llms()
    response = list(
        fast_llm.stream(
            prompt=msg,
            # structured_response_format=None,
        )
    )

    response_str = merge_message_runs(response, chunk_separator="")[0].content

    return {
        "sub_question_answer_check": response_str,
        "base_answer_messages": generate_log_message(
            message="sub - qa_check",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,74 +0,0 @@
```python
import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import (
    REWRITE_PROMPT_MULTI_ORIGINAL,
)
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms


def sub_rewrite(state: BaseQAState) -> dict[str, Any]:
    """
    Transform the initial question into more suitable search queries.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with re-phrased question
    """

    print("---SUB TRANSFORM QUERY---")

    node_start_time = datetime.datetime.now()

    # messages = state["base_answer_messages"]
    question = state["sub_question_str"]

    msg = [
        HumanMessage(
            content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
        )
    ]

    """
    msg = [
        HumanMessage(
            content=REWRITE_PROMPT_MULTI.format(question=question),
        )
    ]
    """

    _, fast_llm = get_default_llms()
    llm_response_list = list(
        fast_llm.stream(
            prompt=msg,
            # structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
            # structured_response_format=RewrittenQueries.model_json_schema(),
        )
    )
    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

    print(f"llm_response: {llm_response}")

    rewritten_queries = llm_response.split("--")
    # rewritten_queries = [llm_response.split("\n")[0]]

    print(f"rewritten_queries: {rewritten_queries}")

    rewritten_queries = RewrittenQueries(rewritten_queries=rewritten_queries)

    return {
        "sub_question_search_queries": rewritten_queries,
        "log_messages": generate_log_message(
            message="sub - rewrite",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,64 +0,0 @@
```python
import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms


def sub_verifier(state: VerifierState) -> dict[str, Any]:
    """
    Check whether the document is relevant for the original user question

    Args:
        state (VerifierState): The current state

    Returns:
        dict: The updated state with the final decision
    """

    # print("---VERIFY OUTPUT---")
    node_start_time = datetime.datetime.now()

    question = state["question"]
    document_content = state["document"].combined_content

    msg = [
        HumanMessage(
            content=VERIFIER_PROMPT.format(
                question=question, document_content=document_content
            )
        )
    ]

    # Grader
    llm, fast_llm = get_default_llms()
    response = list(
        llm.stream(
            prompt=msg,
            # structured_response_format=BinaryDecision.model_json_schema(),
        )
    )

    response_string = merge_message_runs(response, chunk_separator="")[0].content
    # Convert string response to proper dictionary format
    decision_dict = {"decision": response_string.lower()}
    formatted_response = BinaryDecision.model_validate(decision_dict)

    print(f"Verification end time: {datetime.datetime.now()}")

    return {
        "sub_question_verified_retrieval_docs": [state["document"]]
        if formatted_response.decision == "yes"
        else [],
        "log_messages": generate_log_message(
            message=f"sub - verifier: {formatted_response.decision}",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
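`BinaryDecision` itself is not shown in this diff; given the `model_validate` and `model_json_schema` calls and the `decision == "yes"` comparison, it is presumably a Pydantic v2 model along these lines. This is an assumption, not the actual definition:

```python
from typing import Literal

from pydantic import BaseModel


class BinaryDecision(BaseModel):
    # Assumed shape: a single constrained field named "decision".
    decision: Literal["yes", "no"]
```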
@@ -1,90 +0,0 @@
```python
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM


class SubQuestionRetrieverState(TypedDict):
    # The state for the parallel Retrievers. They each need to see only one query
    sub_question_rewritten_query: str


class SubQuestionVerifierState(TypedDict):
    # The state for the parallel verification step. Each node execution needs
    # to see only one question/doc pair
    sub_question_document: InferenceSection
    sub_question: str


class CoreQAInputState(TypedDict):
    sub_question_str: str
    original_question: str


class BaseQAState(TypedDict):
    # The 'core SubQuestion' state.
    original_question: str
    graph_start_time: datetime
    # start time for parallel initial sub-question thread
    sub_query_start_time: datetime
    sub_question_rewritten_queries: list[str]
    sub_question_str: str
    sub_question_search_queries: RewrittenQueries
    sub_question_nr: int
    sub_chunk_ids: Annotated[Sequence[dict], operator.add]
    sub_question_base_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_deduped_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_verified_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_reranked_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    sub_question_answer: str
    sub_question_answer_check: str
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
    sub_qas: Annotated[Sequence[dict], operator.add]
    # Answers sent back to core
    initial_sub_qas: Annotated[Sequence[dict], operator.add]
    primary_llm: LLM
    fast_llm: LLM


class BaseQAOutputState(TypedDict):
    # The 'SubQuestion' output state. Removes all the intermediate states
    sub_question_rewritten_queries: list[str]
    sub_question_str: str
    sub_question_search_queries: list[str]
    sub_question_nr: int
    # Answers sent back to core
    sub_qas: Annotated[Sequence[dict], operator.add]
    # Answers sent back to core
    initial_sub_qas: Annotated[Sequence[dict], operator.add]
    sub_question_base_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_deduped_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_verified_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_reranked_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    sub_question_answer: str
    sub_question_answer_check: str
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
```
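The `Annotated[..., operator.add]` fields above are LangGraph reducers: when parallel branches (for example, the fanned-out verifier nodes) each return an update for the same key, the declared operator merges the updates instead of letting the last writer win. A minimal sketch with illustrative names:

```python
import operator
from typing import Annotated, TypedDict


class DemoState(TypedDict):
    docs: Annotated[list[str], operator.add]


# Branch A returns {"docs": ["a"]} and branch B returns {"docs": ["b"]};
# LangGraph merges them with operator.add -> {"docs": ["a", "b"]}.
```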
@@ -1,46 +0,0 @@
```python
from collections.abc import Hashable
from typing import Union

from langgraph.types import Send

from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.primary_graph.states import VerifierState


def sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Hashable]]:
    # Routes each de-duped retrieved doc to the verifier step - in parallel
    # Notice the 'Send()' API that takes care of the parallelization

    return [
        Send(
            "sub_verifier",
            VerifierState(
                document=doc,
                question=state["sub_question"],
                primary_llm=state["primary_llm"],
                fast_llm=state["fast_llm"],
                graph_start_time=state["graph_start_time"],
            ),
        )
        for doc in state["sub_question_base_retrieval_docs"]
    ]


def sub_continue_to_retrieval(
    state: ResearchQAState,
) -> Union[Hashable, list[Hashable]]:
    # Routes re-written queries to the (parallel) retrieval steps
    # Notice the 'Send()' API that takes care of the parallelization
    return [
        Send(
            "sub_custom_retrieve",
            RetrieverState(
                rewritten_query=query,
                primary_llm=state["primary_llm"],
                fast_llm=state["fast_llm"],
                graph_start_time=state["graph_start_time"],
            ),
        )
        for query in state["sub_question_rewritten_queries"]
    ]
```
@@ -1,93 +0,0 @@
```python
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_retrieval
from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_verifier
from danswer.agent_search.deep_qa_graph.nodes.combine_retrieved_docs import (
    sub_combine_retrieved_docs,
)
from danswer.agent_search.deep_qa_graph.nodes.custom_retrieve import sub_custom_retrieve
from danswer.agent_search.deep_qa_graph.nodes.dummy import sub_dummy
from danswer.agent_search.deep_qa_graph.nodes.final_format import sub_final_format
from danswer.agent_search.deep_qa_graph.nodes.generate import sub_generate
from danswer.agent_search.deep_qa_graph.nodes.qa_check import sub_qa_check
from danswer.agent_search.deep_qa_graph.nodes.rewrite import sub_rewrite
from danswer.agent_search.deep_qa_graph.nodes.verifier import sub_verifier
from danswer.agent_search.deep_qa_graph.states import ResearchQAOutputState
from danswer.agent_search.deep_qa_graph.states import ResearchQAState


def build_deep_qa_graph() -> StateGraph:
    # Define the nodes we will cycle between
    sub_answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState)

    ### Add Nodes ###

    # Dummy node for initial processing
    sub_answers.add_node(node="sub_dummy", action=sub_dummy)

    # Re-writing the sub-question into search queries; the edges below start
    # at "sub_rewrite", so the node has to be registered here
    sub_answers.add_node(node="sub_rewrite", action=sub_rewrite)

    # The retrieval step
    sub_answers.add_node(node="sub_custom_retrieve", action=sub_custom_retrieve)

    # The dedupe step
    sub_answers.add_node(
        node="sub_combine_retrieved_docs", action=sub_combine_retrieved_docs
    )

    # Verifying retrieved information
    sub_answers.add_node(node="sub_verifier", action=sub_verifier)

    # Generating the response
    sub_answers.add_node(node="sub_generate", action=sub_generate)

    # Checking the quality of the answer
    sub_answers.add_node(node="sub_qa_check", action=sub_qa_check)

    # Final formatting of the response
    sub_answers.add_node(node="sub_final_format", action=sub_final_format)

    ### Add Edges ###

    # Rewrite the sub-question into multiple search queries
    sub_answers.add_edge(start_key=START, end_key="sub_rewrite")

    # For each rewritten query, perform a retrieval in parallel
    sub_answers.add_conditional_edges(
        source="sub_rewrite",
        path=sub_continue_to_retrieval,
        path_map=["sub_custom_retrieve"],
    )

    # Combine the retrieved docs for each sub-question from the parallel retrievals
    sub_answers.add_edge(
        start_key="sub_custom_retrieve", end_key="sub_combine_retrieved_docs"
    )

    # Go over all of the combined retrieved docs and verify them against the original question
    sub_answers.add_conditional_edges(
        source="sub_combine_retrieved_docs",
        path=sub_continue_to_verifier,
        path_map=["sub_verifier"],
    )

    # Generate an answer for each verified retrieved doc
    sub_answers.add_edge(start_key="sub_verifier", end_key="sub_generate")

    # Check the quality of the answer
    sub_answers.add_edge(start_key="sub_generate", end_key="sub_qa_check")

    sub_answers.add_edge(start_key="sub_qa_check", end_key="sub_final_format")

    sub_answers.add_edge(start_key="sub_final_format", end_key=END)

    return sub_answers


if __name__ == "__main__":
    # TODO: add the actual question
    inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"}
    sub_answers_graph = build_deep_qa_graph()
    compiled_sub_answers = sub_answers_graph.compile()
    output = compiled_sub_answers.invoke(inputs)
    print("\nOUTPUT:")
    print(output)
```
@@ -1,31 +0,0 @@
```python
from datetime import datetime
from typing import Any

from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]:
    """
    Dedupe the retrieved docs.
    """
    node_start_time = datetime.now()

    sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]

    print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
    dedupe_docs = []
    for base_retrieval_doc in sub_question_base_retrieval_docs:
        if base_retrieval_doc not in dedupe_docs:
            dedupe_docs.append(base_retrieval_doc)

    print(f"Number of deduped docs: {len(dedupe_docs)}")

    return {
        "sub_question_deduped_retrieval_docs": dedupe_docs,
        "log_messages": generate_log_message(
            message="sub - combine_retrieved_docs (dedupe)",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,33 +0,0 @@
```python
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection


def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE SUB---")
    node_start_time = datetime.now()

    # Retrieval
    # TODO: add the actual retrieval, probably from search_tool.run()
    documents: list[InferenceSection] = []

    return {
        "sub_question_base_retrieval_docs": documents,
        "log_messages": generate_log_message(
            message="sub - custom_retrieve",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,21 +0,0 @@
```python
from datetime import datetime
from typing import Any

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_dummy(state: BaseQAState) -> dict[str, Any]:
    """
    Dummy step
    """

    print("---Sub Dummy---")

    return {
        "log_messages": generate_log_message(
            message="sub - dummy",
            node_start_time=datetime.now(),
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,31 +0,0 @@
```python
from datetime import datetime
from typing import Any

from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_final_format(state: ResearchQAState) -> dict[str, Any]:
    """
    Create the final output for the QA subgraph
    """

    print("---SUB FINAL FORMAT---")
    node_start_time = datetime.now()

    return {
        # TODO: Type this
        "sub_qas": [
            {
                "sub_question": state["sub_question"],
                "sub_answer": state["sub_question_answer"],
                "sub_question_nr": state["sub_question_nr"],
                "sub_answer_check": state["sub_question_answer_check"],
            }
        ],
        "log_messages": generate_log_message(
            message="sub - final format",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,56 +0,0 @@
```python
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_generate(state: ResearchQAState) -> dict[str, Any]:
    """
    Generate answer

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the generated answer
    """
    print("---SUB GENERATE---")
    node_start_time = datetime.now()

    question = state["sub_question"]
    docs = state["sub_question_verified_retrieval_docs"]

    print(f"Number of verified retrieval docs for sub-question: {len(docs)}")

    msg = [
        HumanMessage(
            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
        )
    ]

    # Generate the answer only if there are verified docs to ground it
    if len(docs) > 0:
        model = state["fast_llm"]
        response = list(
            model.stream(
                prompt=msg,
            )
        )
        response_str = merge_message_runs(response, chunk_separator="")[0].content
    else:
        response_str = ""

    return {
        "sub_question_answer": response_str,
        "log_messages": generate_log_message(
            message="sub - generate",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,57 +0,0 @@
```python
import json
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.deep_qa_graph.prompts import SUB_CHECK_PROMPT
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_qa_check(state: ResearchQAState) -> dict[str, Any]:
    """
    Check whether the final output satisfies the original user question

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with the final decision
    """

    print("---CHECK SUB OUTPUT---")
    node_start_time = datetime.now()

    sub_answer = state["sub_question_answer"]
    sub_question = state["sub_question"]

    msg = [
        HumanMessage(
            content=SUB_CHECK_PROMPT.format(
                sub_question=sub_question, sub_answer=sub_answer
            )
        )
    ]

    # Grader
    model = state["fast_llm"]
    response = list(
        model.stream(
            prompt=msg,
            structured_response_format=BinaryDecision.model_json_schema(),
        )
    )

    raw_response = json.loads(response[0].pretty_repr())
    formatted_response = BinaryDecision.model_validate(raw_response)

    return {
        "sub_question_answer_check": formatted_response.decision,
        "log_messages": generate_log_message(
            message=f"sub - qa check: {formatted_response.decision}",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,64 +0,0 @@
```python
import json
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.interfaces import LLM


def sub_rewrite(state: ResearchQAState) -> dict[str, Any]:
    """
    Transform the initial question into more suitable search queries.

    Args:
        state (messages): The current state

    Returns:
        dict: The updated state with re-phrased question
    """

    print("---SUB TRANSFORM QUERY---")
    node_start_time = datetime.now()

    question = state["sub_question"]

    msg = [
        HumanMessage(
            content=REWRITE_PROMPT_MULTI.format(question=question),
        )
    ]
    fast_llm: LLM = state["fast_llm"]
    llm_response = list(
        fast_llm.stream(
            prompt=msg,
            structured_response_format=RewrittenQueries.model_json_schema(),
        )
    )

    # Get the rewritten queries in a defined format (json.loads returns a
    # plain dict, not a RewrittenQueries instance)
    rewritten_queries_raw = json.loads(llm_response[0].pretty_repr())

    print(f"rewritten_queries: {rewritten_queries_raw}")

    rewritten_queries = RewrittenQueries(
        rewritten_queries=[
            "music hard to listen to",
            "Music that is not fun or pleasant",
        ]
    )

    print(f"hardcoded rewritten_queries: {rewritten_queries}")

    return {
        "sub_question_rewritten_queries": rewritten_queries,
        "log_messages": generate_log_message(
            message="sub - rewrite",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,59 +0,0 @@
```python
import json
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_verifier(state: VerifierState) -> dict[str, Any]:
    """
    Check whether the document is relevant for the original user question

    Args:
        state (VerifierState): The current state

    Returns:
        dict: The updated state with the final decision
    """

    print("---SUB VERIFY OUTPUT---")
    node_start_time = datetime.now()

    question = state["question"]
    document_content = state["document"].combined_content

    msg = [
        HumanMessage(
            content=VERIFIER_PROMPT.format(
                question=question, document_content=document_content
            )
        )
    ]

    # Grader
    model = state["fast_llm"]
    response = list(
        model.stream(
            prompt=msg,
            structured_response_format=BinaryDecision.model_json_schema(),
        )
    )

    raw_response = json.loads(response[0].pretty_repr())
    formatted_response = BinaryDecision.model_validate(raw_response)

    return {
        "deduped_retrieval_docs": [state["document"]]
        if formatted_response.decision == "yes"
        else [],
        "log_messages": generate_log_message(
            message=f"core - verifier: {formatted_response.decision}",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,13 +0,0 @@
```python
SUB_CHECK_PROMPT = """ \n
Please check whether the suggested answer seems to address the original question.

Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{sub_question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{sub_answer}
\n ------- \n
Please answer with yes or no:"""
```
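Since `sub_qa_check` fills this prompt with `str.format`, its keyword names must match the placeholders exactly, or `format` raises a `KeyError`. Usage as in that caller, with illustrative values:

```python
filled_prompt = SUB_CHECK_PROMPT.format(
    sub_question="What is voice leading?",  # illustrative values
    sub_answer="Voice leading is ...",
)
```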
@@ -1,64 +0,0 @@
```python
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

from danswer.context.search.models import InferenceSection
from danswer.llm.interfaces import LLM


class ResearchQAState(TypedDict):
    # The 'core SubQuestion' state.
    original_question: str
    graph_start_time: datetime
    sub_question_rewritten_queries: list[str]
    sub_question: str
    sub_question_nr: int
    sub_question_base_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_deduped_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_verified_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_reranked_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    sub_question_answer: str
    sub_question_answer_check: str
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
    sub_qas: Annotated[Sequence[dict], operator.add]
    primary_llm: LLM
    fast_llm: LLM


class ResearchQAOutputState(TypedDict):
    # The 'SubQuestion' output state. Removes all the intermediate states
    sub_question_rewritten_queries: list[str]
    sub_question: str
    sub_question_nr: int
    # Answers sent back to core
    sub_qas: Annotated[Sequence[dict], operator.add]
    sub_question_base_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_deduped_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_verified_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_reranked_retrieval_docs: Annotated[
        Sequence[InferenceSection], operator.add
    ]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    sub_question_answer: str
    sub_question_answer_check: str
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
```
@@ -1,75 +0,0 @@
```python
from collections.abc import Hashable
from typing import Union

from langchain_core.messages import HumanMessage
from langgraph.types import Send

from danswer.agent_search.core_qa_graph.states import BaseQAState
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT


def continue_to_initial_sub_questions(
    state: QAState,
) -> Union[Hashable, list[Hashable]]:
    # Routes each initial sub-question to the (parallel) answer subgraphs
    # Notice the 'Send()' API that takes care of the parallelization
    return [
        Send(
            "sub_answers_graph_initial",
            BaseQAState(
                sub_question_str=initial_sub_question["sub_question_str"],
                sub_question_search_queries=initial_sub_question[
                    "sub_question_search_queries"
                ],
                sub_question_nr=initial_sub_question["sub_question_nr"],
                primary_llm=state["primary_llm"],
                fast_llm=state["fast_llm"],
                graph_start_time=state["graph_start_time"],
            ),
        )
        for initial_sub_question in state["initial_sub_questions"]
    ]


def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
    # Routes each decomposed sub-question to the (parallel) answer subgraphs
    # Notice the 'Send()' API that takes care of the parallelization
    return [
        Send(
            "sub_answers_graph",
            ResearchQAState(
                sub_question=sub_question["sub_question_str"],
                sub_question_nr=sub_question["sub_question_nr"],
                graph_start_time=state["graph_start_time"],
                primary_llm=state["primary_llm"],
                fast_llm=state["fast_llm"],
            ),
        )
        for sub_question in state["sub_questions"]
    ]


def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
    print("---GO TO DEEP ANSWER OR END---")

    base_answer = state["base_answer"]

    question = state["original_question"]

    BASE_CHECK_MESSAGE = [
        HumanMessage(
            content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
        )
    ]

    model = state["fast_llm"]
    response = model.invoke(BASE_CHECK_MESSAGE)

    print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")

    if response.pretty_repr() == "no":
        return "decompose"
    else:
        return "end"
```
@@ -1,171 +0,0 @@
```python
from langgraph.graph import END
from langgraph.graph import START
from langgraph.graph import StateGraph

from danswer.agent_search.core_qa_graph.graph_builder import build_core_qa_graph
from danswer.agent_search.deep_qa_graph.graph_builder import build_deep_qa_graph
from danswer.agent_search.primary_graph.edges import continue_to_answer_sub_questions
from danswer.agent_search.primary_graph.edges import continue_to_deep_answer
from danswer.agent_search.primary_graph.edges import continue_to_initial_sub_questions
from danswer.agent_search.primary_graph.nodes.base_wait import base_wait
from danswer.agent_search.primary_graph.nodes.combine_retrieved_docs import (
    combine_retrieved_docs,
)
from danswer.agent_search.primary_graph.nodes.custom_retrieve import custom_retrieve
from danswer.agent_search.primary_graph.nodes.decompose import decompose
from danswer.agent_search.primary_graph.nodes.deep_answer_generation import (
    deep_answer_generation,
)
from danswer.agent_search.primary_graph.nodes.dummy_start import dummy_start
from danswer.agent_search.primary_graph.nodes.entity_term_extraction import (
    entity_term_extraction,
)
from danswer.agent_search.primary_graph.nodes.final_stuff import final_stuff
from danswer.agent_search.primary_graph.nodes.generate_initial import generate_initial
from danswer.agent_search.primary_graph.nodes.main_decomp_base import main_decomp_base
from danswer.agent_search.primary_graph.nodes.rewrite import rewrite
from danswer.agent_search.primary_graph.nodes.sub_qa_level_aggregator import (
    sub_qa_level_aggregator,
)
from danswer.agent_search.primary_graph.nodes.sub_qa_manager import sub_qa_manager
from danswer.agent_search.primary_graph.nodes.verifier import verifier
from danswer.agent_search.primary_graph.states import QAState


def build_core_graph() -> StateGraph:
    # Define the nodes we will cycle between
    core_answer_graph = StateGraph(state_schema=QAState)

    ### Add Nodes ###
    core_answer_graph.add_node(node="dummy_start", action=dummy_start)

    # Re-writing the question
    core_answer_graph.add_node(node="rewrite", action=rewrite)

    # The retrieval step
    core_answer_graph.add_node(node="custom_retrieve", action=custom_retrieve)

    # Combine and dedupe retrieved docs.
    core_answer_graph.add_node(
        node="combine_retrieved_docs", action=combine_retrieved_docs
    )

    # Extract entities, terms and relationships
    core_answer_graph.add_node(
        node="entity_term_extraction", action=entity_term_extraction
    )

    # Verifying that a retrieved doc is relevant
    core_answer_graph.add_node(node="verifier", action=verifier)

    # Initial question decomposition
    core_answer_graph.add_node(node="main_decomp_base", action=main_decomp_base)

    # Build the base QA sub-graph and compile it
    compiled_core_qa_graph = build_core_qa_graph().compile()
    # Add the compiled base QA sub-graph as a node to the core graph
    core_answer_graph.add_node(
        node="sub_answers_graph_initial", action=compiled_core_qa_graph
    )

    # Checking whether the initial answer is in the ballpark
    core_answer_graph.add_node(node="base_wait", action=base_wait)

    # Decompose the question into sub-questions
    core_answer_graph.add_node(node="decompose", action=decompose)

    # Manage the sub-questions
    core_answer_graph.add_node(node="sub_qa_manager", action=sub_qa_manager)

    # Build the research QA sub-graph and compile it
    compiled_deep_qa_graph = build_deep_qa_graph().compile()
    # Add the compiled research QA sub-graph as a node to the core graph
    core_answer_graph.add_node(node="sub_answers_graph", action=compiled_deep_qa_graph)

    # Aggregate the sub-questions
    core_answer_graph.add_node(
        node="sub_qa_level_aggregator", action=sub_qa_level_aggregator
    )

    # aggregate sub questions and answers
    core_answer_graph.add_node(
        node="deep_answer_generation", action=deep_answer_generation
    )

    # A final clean-up step
    core_answer_graph.add_node(node="final_stuff", action=final_stuff)

    # Generating a response after we know the documents are relevant
    core_answer_graph.add_node(node="generate_initial", action=generate_initial)

    ### Add Edges ###

    # start the initial sub-question decomposition
    core_answer_graph.add_edge(start_key=START, end_key="main_decomp_base")

    core_answer_graph.add_conditional_edges(
        source="main_decomp_base",
        path=continue_to_initial_sub_questions,
    )

    # use the retrieved information to generate the answer
    core_answer_graph.add_edge(
        start_key=["verifier", "sub_answers_graph_initial"],
        end_key="generate_initial",
    )
    core_answer_graph.add_edge(start_key="generate_initial", end_key="base_wait")

    core_answer_graph.add_conditional_edges(
        source="base_wait",
        path=continue_to_deep_answer,
        path_map={"decompose": "entity_term_extraction", "end": "final_stuff"},
    )

    core_answer_graph.add_edge(start_key="entity_term_extraction", end_key="decompose")

    core_answer_graph.add_edge(start_key="decompose", end_key="sub_qa_manager")
    core_answer_graph.add_conditional_edges(
        source="sub_qa_manager",
        path=continue_to_answer_sub_questions,
    )

    core_answer_graph.add_edge(
        start_key="sub_answers_graph", end_key="sub_qa_level_aggregator"
    )

    core_answer_graph.add_edge(
        start_key="sub_qa_level_aggregator", end_key="deep_answer_generation"
    )

    core_answer_graph.add_edge(
        start_key="deep_answer_generation", end_key="final_stuff"
    )

    core_answer_graph.add_edge(start_key="final_stuff", end_key=END)

    core_answer_graph.compile()

    return core_answer_graph
```
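A note on the subgraph-as-node pattern used here: a compiled LangGraph graph is itself a runnable, so `add_node` accepts it directly, provided the parent and child graphs share the state keys they exchange (for example, `initial_sub_qas` appears both in `BaseQAOutputState` and in the keys the primary-graph nodes read). A hedged usage sketch, assuming a configured Danswer deployment with default LLMs and a populated index:

```python
# Hedged usage sketch -- the exact required QAState keys are defined in
# primary_graph/states.py, which is not part of this excerpt, so more
# input keys (e.g. graph_start_time) may be needed in practice.
from danswer.llm.factory import get_default_llms

llm, fast_llm = get_default_llms()
app = build_core_graph().compile()
output = app.invoke(
    {
        "original_question": "What makes good music?",
        "primary_llm": llm,
        "fast_llm": fast_llm,
    }
)
print(output.keys())
```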
@@ -1,27 +0,0 @@
```python
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def base_wait(state: QAState) -> dict[str, Any]:
    """
    Ensures that all required steps are completed before proceeding to the next step

    Args:
        state (messages): The current state

    Returns:
        dict: {} (no operation, just logging)
    """

    print("---Base Wait ---")
    node_start_time = datetime.now()
    return {
        "log_messages": generate_log_message(
            message="core - base_wait",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,36 +0,0 @@
```python
from collections.abc import Sequence
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection


def combine_retrieved_docs(state: QAState) -> dict[str, Any]:
    """
    Dedupe the retrieved docs.
    """
    node_start_time = datetime.now()

    base_retrieval_docs: Sequence[InferenceSection] = state["base_retrieval_docs"]

    print(f"Number of docs from steps: {len(base_retrieval_docs)}")
    dedupe_docs: list[InferenceSection] = []
    for base_retrieval_doc in base_retrieval_docs:
        if not any(
            base_retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id
            for doc in dedupe_docs
        ):
            dedupe_docs.append(base_retrieval_doc)

    print(f"Number of deduped docs: {len(dedupe_docs)}")

    return {
        "deduped_retrieval_docs": dedupe_docs,
        "log_messages": generate_log_message(
            message="core - combine_retrieved_docs (dedupe)",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,52 +0,0 @@
```python
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import RetrieverState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.context.search.models import InferenceSection
from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.engine import get_session_context_manager
from danswer.llm.factory import get_default_llms


def custom_retrieve(state: RetrieverState) -> dict[str, Any]:
    """
    Retrieve documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")

    node_start_time = datetime.now()

    query = state["rewritten_query"]

    # Retrieval
    # TODO: add the actual retrieval, probably from search_tool.run()
    llm, fast_llm = get_default_llms()
    with get_session_context_manager() as db_session:
        top_sections = SearchPipeline(
            search_request=SearchRequest(
                query=query,
            ),
            user=None,
            llm=llm,
            fast_llm=fast_llm,
            db_session=db_session,
        ).reranked_sections
    print(len(top_sections))
    documents: list[InferenceSection] = []

    return {
        "base_retrieval_docs": documents,
        "log_messages": generate_log_message(
            message="core - custom_retrieve",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
```
@@ -1,78 +0,0 @@
import json
import re
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_entity_term_extraction
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def decompose(state: QAState) -> dict[str, Any]:
    """Decompose the original question into new, not-yet-tried sub-questions."""
    node_start_time = datetime.now()

    question = state["original_question"]
    base_answer = state["base_answer"]

    # get the entity term extraction dict and properly format it
    entity_term_extraction_dict = state["retrieved_entities_relationships"][
        "retrieved_entities_relationships"
    ]

    entity_term_extraction_str = format_entity_term_extraction(
        entity_term_extraction_dict
    )

    initial_question_answers = state["initial_sub_qas"]

    addressed_question_list = [
        x["sub_question"]
        for x in initial_question_answers
        if x["sub_answer_check"] == "yes"
    ]
    failed_question_list = [
        x["sub_question"]
        for x in initial_question_answers
        if x["sub_answer_check"] == "no"
    ]

    msg = [
        HumanMessage(
            content=DEEP_DECOMPOSE_PROMPT.format(
                question=question,
                entity_term_extraction_str=entity_term_extraction_str,
                base_answer=base_answer,
                answered_sub_questions="\n - ".join(addressed_question_list),
                failed_sub_questions="\n - ".join(failed_question_list),
            ),
        )
    ]

    # Decomposition model
    model = state["fast_llm"]
    response = model.invoke(msg)

    cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr())
    parsed_response = json.loads(cleaned_response)

    sub_questions_dict = {}
    for sub_question_nr, sub_question_dict in enumerate(
        parsed_response["sub_questions"]
    ):
        sub_question_dict["answered"] = False
        sub_question_dict["verified"] = False
        sub_questions_dict[sub_question_nr] = sub_question_dict

    return {
        "decomposed_sub_questions_dict": sub_questions_dict,
        "log_messages": generate_log_message(
            message="deep - decompose",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
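
The fence-stripping regex above assumes the model wraps its JSON in exact ```json fences; a slightly more forgiving sketch (the helper name is ours, not the repo's) grabs the outermost object instead:

import json
import re


def extract_json_object(text: str) -> dict:
    # Find the widest {...} span so stray prose around the JSON is ignored.
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if match is None:
        raise ValueError("no JSON object found in LLM response")
    return json.loads(match.group(0))
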
@@ -1,61 +0,0 @@
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace


# aggregate sub questions and answers
def deep_answer_generation(state: QAState) -> dict[str, Any]:
    """
    Generate the deep answer from the verified docs and the sub-question context

    Args:
        state (QAState): The current state

    Returns:
        dict: The updated state with the deep answer
    """
    print("---DEEP GENERATE---")

    node_start_time = datetime.now()

    question = state["original_question"]
    docs = state["deduped_retrieval_docs"]

    deep_answer_context = state["core_answer_dynamic_context"]

    print(f"Number of verified retrieval docs - deep: {len(docs)}")

    combined_context = normalize_whitespace(
        COMBINED_CONTEXT.format(
            deep_answer_context=deep_answer_context, formated_docs=format_docs(docs)
        )
    )

    msg = [
        HumanMessage(
            content=MODIFIED_RAG_PROMPT.format(
                question=question, combined_context=combined_context
            )
        )
    ]

    # Answer generation model
    model = state["fast_llm"]
    response = model.invoke(msg)

    return {
        "deep_answer": response.content,
        "log_messages": generate_log_message(
            message="deep - deep answer generation",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -1,11 +0,0 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState


def dummy_start(state: QAState) -> dict[str, Any]:
    """
    Dummy node to set the start time
    """
    return {"start_time": datetime.now()}
@@ -1,51 +0,0 @@
import json
import re
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage
from langchain_core.messages import merge_message_runs

from danswer.agent_search.primary_graph.prompts import ENTITY_TERM_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
from danswer.llm.factory import get_default_llms


def entity_term_extraction(state: QAState) -> dict[str, Any]:
    """Extract entities and terms from the question and context"""
    node_start_time = datetime.now()

    question = state["original_question"]
    docs = state["deduped_retrieval_docs"]

    doc_context = format_docs(docs)

    msg = [
        HumanMessage(
            content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
        )
    ]
    _, fast_llm = get_default_llms()
    # Extraction model
    llm_response_list = list(
        fast_llm.stream(
            prompt=msg,
            # structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
            # structured_response_format=RewrittenQueries.model_json_schema(),
        )
    )
    llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content

    cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
    parsed_response = json.loads(cleaned_response)

    return {
        "retrieved_entities_relationships": parsed_response,
        "log_messages": generate_log_message(
            message="deep - entity term extraction",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
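
The stream-then-merge idiom above collects chunks and joins same-type runs with the given separator; a small sketch of the same call with canned messages (contents are illustrative):

from langchain_core.messages import AIMessage, merge_message_runs

pieces = [AIMessage(content='{"entities": '), AIMessage(content="[]}")]
merged = merge_message_runs(pieces, chunk_separator="")[0]
print(merged.content)  # expected: {"entities": []}
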
@@ -1,85 +0,0 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def final_stuff(state: QAState) -> dict[str, Any]:
    """
    Prints the time-ordered log messages plus the base answer, the deep answer
    (if one was generated), and the verified sub-questions with their answers.

    Args:
        state (QAState): The current state

    Returns:
        dict: The updated state with the final log message appended
    """
    print("---FINAL---")
    node_start_time = datetime.now()

    messages = state["log_messages"]
    time_ordered_messages = [x.pretty_repr() for x in messages]
    time_ordered_messages.sort()

    print("Message Log:")
    print("\n".join(time_ordered_messages))

    initial_sub_qas = state["initial_sub_qas"]
    initial_sub_qa_list = []
    for initial_sub_qa in initial_sub_qas:
        if initial_sub_qa["sub_answer_check"] == "yes":
            initial_sub_qa_list.append(
                f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----'
            )

    initial_sub_qa_context = "\n".join(initial_sub_qa_list)

    log_message = generate_log_message(
        message="all - final_stuff",
        node_start_time=node_start_time,
        graph_start_time=state["graph_start_time"],
    )

    print(log_message)
    print("--------------------------------")

    base_answer = state["base_answer"]

    print(f"Final Base Answer:\n{base_answer}")
    print("--------------------------------")
    print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
    print("--------------------------------")

    if not state.get("deep_answer"):
        print("No Deep Answer was required")
        return {
            "log_messages": log_message,
        }

    deep_answer = state["deep_answer"]
    sub_qas = state["sub_qas"]
    sub_qa_list = []
    for sub_qa in sub_qas:
        if sub_qa["sub_answer_check"] == "yes":
            sub_qa_list.append(
                f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----'
            )

    sub_qa_context = "\n".join(sub_qa_list)

    print(f"Final Base Answer:\n{base_answer}")
    print("--------------------------------")
    print(f"Final Deep Answer:\n{deep_answer}")
    print("--------------------------------")
    print("Sub Questions and Answers:")
    print(sub_qa_context)

    return {
        "log_messages": generate_log_message(
            message="all - final_stuff",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -1,52 +0,0 @@
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def generate(state: QAState) -> dict[str, Any]:
    """
    Generate the base answer from the deduped retrieval docs

    Args:
        state (QAState): The current state

    Returns:
        dict: The updated state with the base answer
    """
    print("---GENERATE---")
    node_start_time = datetime.now()

    question = state["original_question"]
    docs = state["deduped_retrieval_docs"]

    print(f"Number of verified retrieval docs: {len(docs)}")

    msg = [
        HumanMessage(
            content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
        )
    ]

    # Answer generation model
    llm = state["fast_llm"]
    response = list(
        llm.stream(
            prompt=msg,
            structured_response_format=None,
        )
    )

    return {
        "base_answer": response[0].pretty_repr(),
        "log_messages": generate_log_message(
            message="core - generate",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -1,72 +0,0 @@
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import format_docs
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def generate_initial(state: QAState) -> dict[str, Any]:
    """
    Generate the initial base answer from the docs and answered sub-questions

    Args:
        state (QAState): The current state

    Returns:
        dict: The updated state with the initial base answer
    """
    print("---GENERATE INITIAL---")
    node_start_time = datetime.now()

    question = state["original_question"]
    docs = state["deduped_retrieval_docs"]
    print(f"Number of verified retrieval docs - base: {len(docs)}")

    sub_question_answers = state["initial_sub_qas"]

    sub_question_answers_list = []

    _SUB_QUESTION_ANSWER_TEMPLATE = """
Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
"""
    for sub_question_answer_dict in sub_question_answers:
        if (
            sub_question_answer_dict["sub_answer_check"] == "yes"
            and len(sub_question_answer_dict["sub_answer"]) > 0
            and sub_question_answer_dict["sub_answer"] != "I don't know"
        ):
            sub_question_answers_list.append(
                _SUB_QUESTION_ANSWER_TEMPLATE.format(
                    sub_question=sub_question_answer_dict["sub_question"],
                    sub_answer=sub_question_answer_dict["sub_answer"],
                )
            )

    sub_question_answer_str = "\n\n------\n\n".join(sub_question_answers_list)

    msg = [
        HumanMessage(
            content=INITIAL_RAG_PROMPT.format(
                question=question,
                context=format_docs(docs),
                answered_sub_questions=sub_question_answer_str,
            )
        )
    ]

    # Answer generation model
    model = state["fast_llm"]
    response = model.invoke(msg)

    return {
        "base_answer": response.pretty_repr(),
        "log_messages": generate_log_message(
            message="core - generate initial",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -1,64 +0,0 @@
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.prompts import INITIAL_DECOMPOSITION_PROMPT
from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def main_decomp_base(state: QAState) -> dict[str, Any]:
    """
    Perform an initial question decomposition, incl. one search term

    Args:
        state (QAState): The current state

    Returns:
        dict: The updated state with initial decomposition
    """
    print("---INITIAL DECOMP---")
    node_start_time = datetime.now()

    question = state["original_question"]

    msg = [
        HumanMessage(
            content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
        )
    ]

    # Get the rewritten queries in a defined format
    model = state["fast_llm"]
    response = model.invoke(msg)

    content = response.pretty_repr()
    list_of_subquestions = clean_and_parse_list_string(content)

    decomp_list = []

    for sub_question_nr, sub_question in enumerate(list_of_subquestions):
        sub_question_str = sub_question["sub_question"].strip()
        # temporarily
        sub_question_search_queries = [sub_question["search_term"]]

        decomp_list.append(
            {
                "sub_question_str": sub_question_str,
                "sub_question_search_queries": sub_question_search_queries,
                "sub_question_nr": sub_question_nr,
            }
        )

    return {
        "initial_sub_questions": decomp_list,
        "sub_query_start_time": node_start_time,
        "log_messages": generate_log_message(
            message="core - initial decomp",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -1,55 +0,0 @@
import json
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def rewrite(state: QAState) -> dict[str, Any]:
    """
    Transform the initial question into more suitable search queries.

    Args:
        state (QAState): The current state

    Returns:
        dict: The updated state with the rewritten queries
    """
    print("---STARTING GRAPH---")
    graph_start_time = datetime.now()

    print("---TRANSFORM QUERY---")
    node_start_time = datetime.now()

    question = state["original_question"]

    msg = [
        HumanMessage(
            content=REWRITE_PROMPT_MULTI.format(question=question),
        )
    ]

    # Get the rewritten queries in a defined format
    fast_llm = state["fast_llm"]
    llm_response = list(
        fast_llm.stream(
            prompt=msg,
            structured_response_format=RewrittenQueries.model_json_schema(),
        )
    )

    formatted_response = RewrittenQueries.model_validate(
        json.loads(llm_response[0].pretty_repr())
    )

    return {
        "rewritten_queries": formatted_response.rewritten_queries,
        "log_messages": generate_log_message(
            message="core - rewrite",
            node_start_time=node_start_time,
            graph_start_time=graph_start_time,
        ),
    }
@@ -1,39 +0,0 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


# aggregate sub questions and answers
def sub_qa_level_aggregator(state: QAState) -> dict[str, Any]:
    sub_qas = state["sub_qas"]

    node_start_time = datetime.now()

    dynamic_context_list = [
        "Below you will find useful information to answer the original question:"
    ]
    checked_sub_qas = []

    for core_answer_sub_qa in sub_qas:
        question = core_answer_sub_qa["sub_question"]
        answer = core_answer_sub_qa["sub_answer"]
        verified = core_answer_sub_qa["sub_answer_check"]

        if verified == "yes":
            dynamic_context_list.append(
                f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n"
            )
            checked_sub_qas.append({"sub_question": question, "sub_answer": answer})
    dynamic_context = "\n".join(dynamic_context_list)

    return {
        "core_answer_dynamic_context": dynamic_context,
        "checked_sub_qas": checked_sub_qas,
        "log_messages": generate_log_message(
            message="deep - sub qa level aggregator",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -1,28 +0,0 @@
from datetime import datetime
from typing import Any

from danswer.agent_search.primary_graph.states import QAState
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def sub_qa_manager(state: QAState) -> dict[str, Any]:
    """Set up the decomposed sub-questions for the parallel answering step."""
    node_start_time = datetime.now()

    sub_questions_dict = state["decomposed_sub_questions_dict"]

    sub_questions = {}

    for sub_question_nr, sub_question_dict in sub_questions_dict.items():
        sub_questions[sub_question_nr] = sub_question_dict["sub_question"]

    return {
        "sub_questions": sub_questions,
        "num_new_question_iterations": 0,
        "log_messages": generate_log_message(
            message="deep - sub qa manager",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -1,59 +0,0 @@
import json
from datetime import datetime
from typing import Any

from langchain_core.messages import HumanMessage

from danswer.agent_search.primary_graph.states import VerifierState
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
from danswer.agent_search.shared_graph_utils.utils import generate_log_message


def verifier(state: VerifierState) -> dict[str, Any]:
    """
    Check whether the document is relevant for the original user question

    Args:
        state (VerifierState): The current state

    Returns:
        dict: The updated state with the final decision
    """
    print("---VERIFY OUTPUT---")
    node_start_time = datetime.now()

    question = state["question"]
    document_content = state["document"].combined_content

    msg = [
        HumanMessage(
            content=VERIFIER_PROMPT.format(
                question=question, document_content=document_content
            )
        )
    ]

    # Grader
    llm = state["fast_llm"]
    response = list(
        llm.stream(
            prompt=msg,
            structured_response_format=BinaryDecision.model_json_schema(),
        )
    )

    raw_response = json.loads(response[0].pretty_repr())
    formatted_response = BinaryDecision.model_validate(raw_response)

    return {
        "deduped_retrieval_docs": [state["document"]]
        if formatted_response.decision == "yes"
        else [],
        "log_messages": generate_log_message(
            message=f"core - verifier: {formatted_response.decision}",
            node_start_time=node_start_time,
            graph_start_time=state["graph_start_time"],
        ),
    }
@@ -1,86 +0,0 @@
INITIAL_DECOMPOSITION_PROMPT = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individual entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n

For each sub-question, please also create one search term that can be used to retrieve relevant
documents from a document store.

Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Please formulate your answer as a list of json objects with the following format:

[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]

Answer:
"""

INITIAL_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.

The information provided below consists of:
1) a number of answered sub-questions - these are very important(!) and definitely should be
considered to answer the question.
2) a number of documents that were also deemed relevant for the question.

If you don't know the answer or if the provided information is empty or insufficient, just say
"I don't know". Do not use your internal knowledge!

Again, only use the provided information and do not use your internal knowledge! It is a matter of life
and death that you do NOT use your internal knowledge, just the provided information!

Try to keep your answer concise.

And here is the question and the provided information:
\n
\nQuestion:\n {question}

\nAnswered Sub-questions:\n {answered_sub_questions}

\nContext:\n {context} \n\n
\n\n

Answer:"""

ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retrieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.

\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n

Please format your answer as a json object in the following format:

{{"retrieved_entities_relationships": {{
"entities": [{{
"entity_name": <assign a name for the entity>,
"entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
}}],
"relationships": [{{
"name": <assign a name for the relationship>,
"type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
"entities": [<related entity name 1>, <related entity name 2>]
}}],
"terms": [{{
"term_name": <assign a name for the term>,
"term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
"similar_to": <list terms that are similar to this term>
}}]
}}
}}
"""
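
A hedged usage sketch for the decomposition prompt above; the reply string is a made-up example of the shape the prompt requests:

import ast

prompt = INITIAL_DECOMPOSITION_PROMPT.format(question="How did company A perform?")
# A model reply matching the requested format would parse like this:
reply = "[{'sub_question': 'What were sales for company A?', 'search_term': 'company A sales'}]"
for item in ast.literal_eval(reply):
    print(item["sub_question"], "->", item["search_term"])
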
@@ -1,73 +0,0 @@
import operator
from collections.abc import Sequence
from datetime import datetime
from typing import Annotated
from typing import TypedDict

from langchain_core.messages import BaseMessage
from langgraph.graph.message import add_messages

from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
from danswer.context.search.models import InferenceSection


class QAState(TypedDict):
    # The 'main' state of the answer graph
    original_question: str
    graph_start_time: datetime
    # start time for parallel initial sub-question thread
    sub_query_start_time: datetime
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
    rewritten_queries: RewrittenQueries
    sub_questions: list[dict]
    initial_sub_questions: list[dict]
    ranked_subquestion_ids: list[int]
    decomposed_sub_questions_dict: dict
    rejected_sub_questions: Annotated[list[str], operator.add]
    rejected_sub_questions_handled: bool
    sub_qas: Annotated[Sequence[dict], operator.add]
    initial_sub_qas: Annotated[Sequence[dict], operator.add]
    checked_sub_qas: Annotated[Sequence[dict], operator.add]
    base_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
    deduped_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
    reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
    retrieved_entities_relationships: dict
    questions_context: list[dict]
    qa_level: int
    top_chunks: list[InferenceSection]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    num_new_question_iterations: int
    core_answer_dynamic_context: str
    dynamic_context: str
    initial_base_answer: str
    base_answer: str
    deep_answer: str


class QAOuputState(TypedDict):
    # The 'main' output state of the answer graph. Removes all the intermediate states
    original_question: str
    log_messages: Annotated[Sequence[BaseMessage], add_messages]
    sub_questions: list[dict]
    sub_qas: Annotated[Sequence[dict], operator.add]
    initial_sub_qas: Annotated[Sequence[dict], operator.add]
    checked_sub_qas: Annotated[Sequence[dict], operator.add]
    reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
    retrieved_entities_relationships: dict
    top_chunks: list[InferenceSection]
    sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
    base_answer: str
    deep_answer: str


class RetrieverState(TypedDict):
    # The state for the parallel Retrievers. They each need to see only one query
    rewritten_query: str
    graph_start_time: datetime


class VerifierState(TypedDict):
    # The state for the parallel verification step. Each node execution needs to see only one question/doc pair
    document: InferenceSection
    question: str
    graph_start_time: datetime
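
The `Annotated[..., operator.add]` fields above are LangGraph reducers: updates from parallel branches get merged with the reducer instead of overwriting each other. A minimal sketch with hypothetical node names:

import operator
from typing import Annotated, TypedDict

from langgraph.graph import END, START, StateGraph


class DocsState(TypedDict):
    docs: Annotated[list[str], operator.add]


def retrieve_a(state: DocsState) -> dict:
    return {"docs": ["doc-a"]}


def retrieve_b(state: DocsState) -> dict:
    return {"docs": ["doc-b"]}


graph = StateGraph(DocsState)
graph.add_node("retrieve_a", retrieve_a)
graph.add_node("retrieve_b", retrieve_b)
graph.add_edge(START, "retrieve_a")
graph.add_edge(START, "retrieve_b")
graph.add_edge("retrieve_a", END)
graph.add_edge("retrieve_b", END)
result = graph.compile().invoke({"docs": []})
# result["docs"] now holds both entries, e.g. ["doc-a", "doc-b"]
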
@@ -1,22 +0,0 @@
from danswer.agent_search.primary_graph.graph_builder import build_core_graph
from danswer.llm.answering.answer import AnswerStream
from danswer.llm.interfaces import LLM
from danswer.tools.tool import Tool


def run_graph(
    query: str,
    llm: LLM,
    tools: list[Tool],
) -> AnswerStream:
    graph = build_core_graph()

    inputs = {
        "original_question": query,
        "messages": [],
        "tools": tools,
        "llm": llm,
    }
    compiled_graph = graph.compile()
    output = compiled_graph.invoke(input=inputs)
    yield from output
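
For incremental output one could stream node updates instead of a single invoke; a sketch against the same compiled graph and inputs (stream_mode is a standard langgraph argument):

for update in compiled_graph.stream(input=inputs, stream_mode="updates"):
    print(update)  # one dict per node as it finishes
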
@@ -1,16 +0,0 @@
from typing import Literal

from pydantic import BaseModel


# Pydantic models for structured outputs
class RewrittenQueries(BaseModel):
    rewritten_queries: list[str]


class BinaryDecision(BaseModel):
    decision: Literal["yes", "no"]


class SubQuestions(BaseModel):
    sub_questions: list[str]
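
A quick sketch of how these models pin down loosely formatted LLM replies (the raw string is illustrative):

import json

raw_reply = '{"decision": "yes"}'
decision = BinaryDecision.model_validate(json.loads(raw_reply))
assert decision.decision == "yes"  # anything but "yes"/"no" raises a ValidationError
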
@@ -1,342 +0,0 @@
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
Please convert an initial user question into 2-3 more appropriate short and pointed search queries for retrieval from a
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
enabling the system to search more broadly.
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the query text): """


REWRITE_PROMPT_MULTI = """ \n
Please create a list of 2-3 sample documents that could answer an original question. Each document
should be about as long as the original question. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """

BASE_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the context provided below - and only the
provided context - to answer the question. If you don't know the answer or if the provided context is
empty, just say "I don't know". Do not use your internal knowledge!

Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!

Use three sentences maximum and keep the answer concise.
\nQuestion:\n {question} \nContext:\n {context} \n\n
\n\n
Answer:"""

BASE_CHECK_PROMPT = """ \n
Please check whether 1) the suggested answer seems to fully address the original question AND 2) the
original question requests a simple, factual answer, and there are no ambiguities, judgements,
aggregations, or any other complications that may require extra context. (I.e., if the question is
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)

Please only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the proposed answer:
\n ------- \n
{base_answer}
\n ------- \n
Please answer with yes or no:"""

VERIFIER_PROMPT = """ \n
Please check whether the document seems to be relevant for the answer of the original question. Please
only answer with 'yes' or 'no' \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n
Here is the document text:
\n ------- \n
{document_content}
\n ------- \n
Please answer with yes or no:"""

INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individual entities
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n

Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Please formulate your answer as a list of subquestions:

Answer:
"""

REWRITE_PROMPT_SINGLE = """ \n
Please convert an initial user question into a more appropriate search query for retrieval from a
document store. \n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Formulate the query: """

MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
Use three sentences maximum and keep the answer concise.
Pay also particular attention to the sub-questions and their answers - they may well enrich the answer.
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
use your internal knowledge, just the provided information!

\nQuestion: {question}
\nContext: {combined_context} \n

Answer:"""

ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other sub-questions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are given to you so that
you have an idea of how the available data looks like.

Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
considering:

1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context

The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.

Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcomings of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?")
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question to use as a search phrase for the document store>}}, ...]

\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n

Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n

Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n

And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n

Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?

Please think through it step by step and then generate the list of json dictionaries with the following
format:

{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question to use as a search phrase for the document store>}},
...]}} """

DEEP_DECOMPOSE_PROMPT = """ \n
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
good enough. Also, some sub-questions had been answered and this information has been used to provide
the initial answer. Some other sub-questions may have been suggested based on little knowledge, but they
were not directly answerable. Also, some entities, relationships and terms are given to you so that
you have an idea of how the available data looks like.

Your role is to generate 4-6 new sub-questions that would help to answer the initial question,
considering:

1) The initial question
2) The initial answer that was found to be unsatisfactory
3) The sub-questions that were answered
4) The sub-questions that were suggested but not answered
5) The entities, relationships and terms that were extracted from the context

The individual questions should be answerable by a good RAG system.
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question, but in a way that does
not duplicate questions that were already tried.

Additional Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
resolve ambiguities, or address shortcomings of the initial answer
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?")
- For each sub-question, please also provide a search term that can be used to retrieve relevant
documents from a document store.
\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Here is the initial sub-optimal answer:
\n ------- \n
{base_answer}
\n ------- \n

Here are the sub-questions that were answered:
\n ------- \n
{answered_sub_questions}
\n ------- \n

Here are the sub-questions that were suggested but not answered:
\n ------- \n
{failed_sub_questions}
\n ------- \n

And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n

Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Again, please find questions that are NOT overlapping too much with the already answered
sub-questions or those that already were suggested and failed.
In other words - what can we try in addition to what has been tried so far?

Generate the list of json dictionaries with the following format:

{{"sub_questions": [{{"sub_question": <sub-question>,
"search_term": <rewrite the sub-question to use as a search phrase for the document store>}},
...]}} """

DECOMPOSE_PROMPT = """ \n
For an initial user question, please generate 5-10 individual sub-questions whose answers would help
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question.

In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
document store, expressed as entities, relationships and terms. You can also think about the types
mentioned in brackets

Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
and/or resolve ambiguities
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that it can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?")
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question to use as a search phrase for the document store>}}, ...]

\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n

Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Don't be too specific unless the original question is specific.
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question to use as a search phrase for the document store>}},
...]}} """

#### Consolidations
COMBINED_CONTEXT = """-------
Below you will find useful information to answer the original question. First, you see a number of
sub-questions with their answers. This information should be considered to be more focused and
somewhat more specific to the original question as it tries to contextualize facts.
After that you will see the documents that were considered to be relevant to answer the original question.

Here are the sub-questions and their answers:
\n\n {deep_answer_context} \n\n
\n\n Here are the documents that were considered to be relevant to answer the original question:
\n\n {formated_docs} \n\n
----------------
"""

SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
Below you will find a question that we ultimately want to answer (the original question) and a list of
motivations in arbitrary order for generated sub-questions that are supposed to help us answer the
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
motivation and 2 is less relevant.)

Please rank the motivations in order of relevance for answering the original question. Also, try to
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
Ultimately, create a list with the motivation numbers where the number of the most relevant
motivations comes first.

Here is the original question:
\n\n {original_question} \n\n
\n\n Here is the list of sub-question motivations:
\n\n {sub_question_explanations} \n\n
----------------

Please think step by step and then generate the ranked list of motivations.

Please format your answer as a json object in the following format:
{{"reasoning": <explain your reasoning for the ranking>,
"ranked_motivations": <ranked list of motivation numbers>}}
"""
@@ -1,91 +0,0 @@
import ast
import json
import re
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any

from danswer.context.search.models import InferenceSection


def normalize_whitespace(text: str) -> str:
    """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
    return re.sub(r"\s+", " ", text.strip())


# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
    return "\n\n".join(doc.combined_content for doc in docs)


def clean_and_parse_list_string(json_string: str) -> list[dict]:
    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())
    # Parse the cleaned string into a Python list of dicts
    return ast.literal_eval(cleaned_string)


def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())
    # Parse the cleaned string into a Python dictionary
    return json.loads(cleaned_string)


def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
    entities = entity_term_extraction_dict["entities"]
    terms = entity_term_extraction_dict["terms"]
    relationships = entity_term_extraction_dict["relationships"]

    entity_strs = ["\nEntities:\n"]
    for entity in entities:
        entity_str = f"{entity['entity_name']} ({entity['entity_type']})"
        entity_strs.append(entity_str)

    entity_str = "\n - ".join(entity_strs)

    relationship_strs = ["\n\nRelationships:\n"]
    for relationship in relationships:
        relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
        relationship_strs.append(relationship_str)

    relationship_str = "\n - ".join(relationship_strs)

    term_strs = ["\n\nTerms:\n"]
    for term in terms:
        term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
        term_strs.append(term_str)

    term_str = "\n - ".join(term_strs)

    return "\n".join(entity_strs + relationship_strs + term_strs)


def _format_time_delta(time: timedelta) -> str:
    seconds_from_start = f"{((time).seconds):03d}"
    microseconds_from_start = f"{((time).microseconds):06d}"
    return f"{seconds_from_start}.{microseconds_from_start}"


def generate_log_message(
    message: str,
    node_start_time: datetime,
    graph_start_time: datetime | None = None,
) -> str:
    current_time = datetime.now()

    if graph_start_time is not None:
        graph_time_str = _format_time_delta(current_time - graph_start_time)
    else:
        graph_time_str = "N/A"

    node_time_str = _format_time_delta(current_time - node_start_time)

    return f"{graph_time_str} ({node_time_str} s): {message}"
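
A usage sketch for the timing helper above, with illustrative offsets:

from datetime import datetime, timedelta

graph_start = datetime.now() - timedelta(seconds=12)
node_start = datetime.now() - timedelta(seconds=1)
print(generate_log_message("core - example", node_start, graph_start))
# prints e.g. "012.000042 (001.000017 s): core - example"
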
@@ -58,7 +58,6 @@ from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_VERIFICATION
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
@@ -132,11 +131,12 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:


def user_needs_to_be_verified() -> bool:
    # all other auth types besides basic should require users to be
    # verified
    return not DISABLE_VERIFICATION and (
        AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
    )
    if AUTH_TYPE == AuthType.BASIC:
        return REQUIRE_EMAIL_VERIFICATION

    # For other auth types, if the user is authenticated it's assumed that
    # the user is already verified via the external IDP
    return False


def verify_email_is_invited(email: str) -> None:
@@ -6,27 +6,27 @@ from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall

from danswer.chat.llm_response_handler import LLMResponseHandlerManager
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import default_build_system_message
from danswer.llm.answering.prompts.build import default_build_user_message
from danswer.llm.answering.stream_processing.answer_response_handler import (
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.chat.prompt_builder.build import default_build_system_message
from danswer.chat.prompt_builder.build import default_build_user_message
from danswer.chat.prompt_builder.build import LLMCall
from danswer.chat.stream_processing.answer_response_handler import (
    CitationResponseHandler,
)
from danswer.llm.answering.stream_processing.answer_response_handler import (
from danswer.chat.stream_processing.answer_response_handler import (
    DummyAnswerResponseHandler,
)
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
from danswer.chat.stream_processing.utils import map_document_id_order
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse
@@ -26,7 +26,7 @@ from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.models import PreviousMessage
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.custom.custom_tool import (
@@ -1,58 +1,22 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import ResponsePart
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolCallFinalResult
|
||||
from danswer.tools.models import ToolCallKickoff
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
class LLMCall(BaseModel__v1):
|
||||
prompt_builder: AnswerPromptBuilder
|
||||
tools: list[Tool]
|
||||
force_use_tool: ForceUseTool
|
||||
files: list[InMemoryChatFile]
|
||||
tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
|
||||
using_tool_calling_llm: bool
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
from danswer.chat.prompt_builder.build import LLMCall
|
||||
from danswer.chat.stream_processing.answer_response_handler import AnswerResponseHandler
|
||||
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
class LLMResponseHandlerManager:
|
||||
def __init__(
|
||||
self,
|
||||
tool_handler: "ToolResponseHandler",
|
||||
answer_handler: "AnswerResponseHandler",
|
||||
tool_handler: ToolResponseHandler,
|
||||
answer_handler: AnswerResponseHandler,
|
||||
is_cancelled: Callable[[], bool],
|
||||
):
|
||||
self.tool_handler = tool_handler
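
# Rough wiring sketch (assumed, not from the diff): with the handlers now
# importable directly, the manager is constructed with concrete instances and a
# cancellation probe that is polled while the LLM stream is consumed. The
# handler constructor arguments shown here are assumptions.
manager = LLMResponseHandlerManager(
    tool_handler=ToolResponseHandler(tools=[]),
    answer_handler=DummyAnswerResponseHandler(),
    is_cancelled=lambda: False,  # user stop-requests plug in here
)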
@@ -1,10 +1,14 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator

from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
@@ -12,8 +16,15 @@ from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.llm.override_models import PromptOverride
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType

if TYPE_CHECKING:
    from danswer.db.models import Prompt


class LlmDoc(BaseModel):
    """This contains the minimal set of information for the LLM portion, including citations"""
@@ -210,3 +221,109 @@ AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
class LLMMetricsContainer(BaseModel):
    prompt_tokens: int
    response_tokens: int


StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]


class DocumentPruningConfig(BaseModel):
    max_chunks: int | None = None
    max_window_percentage: float | None = None
    max_tokens: int | None = None
    # different pruning behavior is expected when the
    # user manually selects documents they want to chat with
    # e.g. we don't want to truncate each document to be no more
    # than one chunk long
    is_manually_selected_docs: bool = False
    # If user specifies to include additional context Chunks for each match, then different pruning
    # is used. As many Sections as possible are included, and the last Section is truncated
    # If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
    # Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
    use_sections: bool = True
    # If using tools, then we need to consider the tool length
    tool_num_tokens: int = 0
    # If using a tool message to represent the docs, then we have to JSON serialize
    # the document content, which adds to the token count.
    using_tool_message: bool = False


class ContextualPruningConfig(DocumentPruningConfig):
    num_chunk_multiple: int

    @classmethod
    def from_doc_pruning_config(
        cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
    ) -> "ContextualPruningConfig":
        return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())
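
# Usage sketch (illustrative values): the classmethod copies every base pruning
# field via .dict() and layers the chunk multiplier on top.
_base = DocumentPruningConfig(max_chunks=10, tool_num_tokens=200)
_contextual = ContextualPruningConfig.from_doc_pruning_config(
    num_chunk_multiple=3, doc_pruning_config=_base
)
assert _contextual.max_chunks == 10         # inherited from the base config
assert _contextual.num_chunk_multiple == 3  # added by the subclass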


class CitationConfig(BaseModel):
    all_docs_useful: bool = False


class QuotesConfig(BaseModel):
    pass


class AnswerStyleConfig(BaseModel):
    citation_config: CitationConfig | None = None
    quotes_config: QuotesConfig | None = None
    document_pruning_config: DocumentPruningConfig = Field(
        default_factory=DocumentPruningConfig
    )
    # forces the LLM to return a structured response, see
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    # right now, only used by the simple chat API
    structured_response_format: dict | None = None

    @model_validator(mode="after")
    def check_quotes_and_citation(self) -> "AnswerStyleConfig":
        if self.citation_config is None and self.quotes_config is None:
            raise ValueError(
                "One of `citation_config` or `quotes_config` must be provided"
            )

        if self.citation_config is not None and self.quotes_config is not None:
            raise ValueError(
                "Only one of `citation_config` or `quotes_config` must be provided"
            )

        return self
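
# The validator enforces exactly-one-of semantics; illustrative calls
# (pydantic surfaces the ValueError as a ValidationError, a ValueError subclass):
AnswerStyleConfig(citation_config=CitationConfig())  # ok
AnswerStyleConfig(quotes_config=QuotesConfig())      # ok
for _bad_kwargs in (
    {},  # neither provided
    {"citation_config": CitationConfig(), "quotes_config": QuotesConfig()},  # both
):
    try:
        AnswerStyleConfig(**_bad_kwargs)
    except ValueError:
        pass  # rejected as expected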


class PromptConfig(BaseModel):
    """Final representation of the Prompt configuration passed
    into the `Answer` object."""

    system_prompt: str
    task_prompt: str
    datetime_aware: bool
    include_citations: bool

    @classmethod
    def from_model(
        cls, model: "Prompt", prompt_override: PromptOverride | None = None
    ) -> "PromptConfig":
        override_system_prompt = (
            prompt_override.system_prompt if prompt_override else None
        )
        override_task_prompt = prompt_override.task_prompt if prompt_override else None

        return cls(
            system_prompt=override_system_prompt or model.system_prompt,
            task_prompt=override_task_prompt or model.task_prompt,
            datetime_aware=model.datetime_aware,
            include_citations=model.include_citations,
        )

    model_config = ConfigDict(frozen=True)
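
# Override precedence, sketched with stand-ins (the real arguments are a db
# `Prompt` row and a `PromptOverride`; SimpleNamespace is just for illustration):
from types import SimpleNamespace

_db_prompt = SimpleNamespace(
    system_prompt="default system", task_prompt="default task",
    datetime_aware=True, include_citations=True,
)
_override = SimpleNamespace(system_prompt="custom system", task_prompt=None)
_config = PromptConfig.from_model(_db_prompt, prompt_override=_override)
assert _config.system_prompt == "custom system"  # override wins when set
assert _config.task_prompt == "default task"     # empty override falls back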


ResponsePart = (
    DanswerAnswerPiece
    | CitationInfo
    | ToolCallKickoff
    | ToolResponse
    | ToolCallFinalResult
    | StreamStopInfo
)

@@ -6,19 +6,24 @@ from typing import cast

from sqlalchemy.orm import Session

from danswer.chat.answer import Answer
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import create_temporary_persona
from danswer.chat.models import AllCitations
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import ChatDanswerBotResponse
from danswer.chat.models import CitationConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DocumentPruningConfig
from danswer.chat.models import FileChatDisplay
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import PromptConfig
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
@@ -57,16 +62,11 @@ from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
from danswer.file_store.utils import load_all_chat_files
from danswer.file_store.utils import save_files_from_urls
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.file_store.utils import save_files
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -119,6 +119,7 @@ from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()
@@ -302,6 +303,7 @@ def stream_chat_message_objects(
    3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
    4. [always] Details on the final AI response message that is created
    """
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    use_existing_user_message = new_msg_req.use_existing_user_message
    existing_assistant_message_id = new_msg_req.existing_assistant_message_id

@@ -678,7 +680,8 @@ def stream_chat_message_objects(

    reference_db_search_docs = None
    qa_docs_response = None
    ai_message_files = None  # any files to associate with the AI message e.g. dall-e generated images
    # any files to associate with the AI message e.g. dall-e generated images
    ai_message_files = []
    dropped_indices = None
    tool_result = None

@@ -733,8 +736,14 @@ def stream_chat_message_objects(
                    list[ImageGenerationResponse], packet.response
                )

                file_ids = save_files_from_urls(
                    [img.url for img in img_generation_response]
                file_ids = save_files(
                    urls=[img.url for img in img_generation_response if img.url],
                    base64_files=[
                        img.image_data
                        for img in img_generation_response
                        if img.image_data
                    ],
                    tenant_id=tenant_id,
                )
                ai_message_files = [
                    FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
@@ -760,15 +769,19 @@ def stream_chat_message_objects(
                    or custom_tool_response.response_type == "csv"
                ):
                    file_ids = custom_tool_response.tool_result.file_ids
                    ai_message_files = [
                        FileDescriptor(
                            id=str(file_id),
                            type=ChatFileType.IMAGE
                            if custom_tool_response.response_type == "image"
                            else ChatFileType.CSV,
                        )
                        for file_id in file_ids
                    ]
                    ai_message_files.extend(
                        [
                            FileDescriptor(
                                id=str(file_id),
                                type=(
                                    ChatFileType.IMAGE
                                    if custom_tool_response.response_type == "image"
                                    else ChatFileType.CSV
                                ),
                            )
                            for file_id in file_ids
                        ]
                    )
                    yield FileChatDisplay(
                        file_ids=[str(file_id) for file_id in file_ids]
                    )
@@ -818,7 +831,8 @@ def stream_chat_message_objects(
            citations_list=answer.citations,
            db_docs=reference_db_search_docs,
        )
    yield AllCitations(citations=answer.citations)
    if not answer.is_cancelled():
        yield AllCitations(citations=answer.citations)

    # Saving Gen AI answer and responding with message info
    tool_name_to_tool_id: dict[str, int] = {}

@@ -4,20 +4,26 @@ from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic.v1 import BaseModel as BaseModel__v1

from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from danswer.chat.prompt_builder.utils import translate_history_to_basemessages
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool


def default_build_system_message(
@@ -139,3 +145,15 @@ class AnswerPromptBuilder:
        return drop_messages_history_overflow(
            final_messages_with_tokens, self.max_tokens
        )


class LLMCall(BaseModel__v1):
    prompt_builder: AnswerPromptBuilder
    tools: list[Tool]
    force_use_tool: ForceUseTool
    files: list[InMemoryChatFile]
    tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
    using_tool_calling_llm: bool

    class Config:
        arbitrary_types_allowed = True
@@ -2,12 +2,12 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
@@ -1,10 +1,10 @@
from langchain.schema.messages import HumanMessage

from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.context.search.models import InferenceChunk
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK

backend/danswer/chat/prompt_builder/utils.py (new file, 62 lines)
@@ -0,0 +1,62 @@
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage

from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT


def build_dummy_prompt(
    system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
    if retrieval_disabled:
        return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
            user_query="<USER_QUERY>",
            system_prompt=system_prompt,
            task_prompt=task_prompt,
        ).strip()

    return PARAMATERIZED_PROMPT.format(
        context_docs_str="<CONTEXT_DOCS>",
        user_query="<USER_QUERY>",
        system_prompt=system_prompt,
        task_prompt=task_prompt,
    ).strip()


def translate_danswer_msg_to_langchain(
    msg: ChatMessage | PreviousMessage,
) -> BaseMessage:
    files: list[InMemoryChatFile] = []

    # If the message is a `ChatMessage`, it doesn't have the downloaded files
    # attached. Just ignore them for now.
    if not isinstance(msg, ChatMessage):
        files = msg.files
    content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)

    if msg.message_type == MessageType.SYSTEM:
        raise ValueError("System messages are not currently part of history")
    if msg.message_type == MessageType.ASSISTANT:
        return AIMessage(content=content)
    if msg.message_type == MessageType.USER:
        return HumanMessage(content=content)

    raise ValueError(f"New message type {msg.message_type} not handled")


def translate_history_to_basemessages(
    history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
    history_basemessages = [
        translate_danswer_msg_to_langchain(msg)
        for msg in history
        if msg.token_count != 0
    ]
    history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
    return history_basemessages, history_token_counts
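
# Usage sketch (history construction elided; entries come from the chat session):
# the two returned lists stay parallel because both filter on token_count != 0.
_messages, _token_counts = translate_history_to_basemessages(history)
for _msg, _tokens in zip(_messages, _token_counts):
    print(type(_msg).__name__, _tokens)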
@@ -5,16 +5,16 @@ from typing import TypeVar

from pydantic import BaseModel

from danswer.chat.models import ContextualPruningConfig
from danswer.chat.models import (
    LlmDoc,
)
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.citations_prompt import compute_max_document_tokens
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content
@@ -3,13 +3,11 @@ from collections.abc import Generator

from langchain_core.messages import BaseMessage

from danswer.chat.llm_response_handler import ResponsePart
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.answering.stream_processing.citation_processing import (
    CitationProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.chat.stream_processing.citation_processing import CitationProcessor
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -4,8 +4,8 @@ from collections.abc import Generator
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger

@@ -4,8 +4,8 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall

from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.chat.models import ResponsePart
from danswer.chat.prompt_builder.build import LLMCall
from danswer.llm.interfaces import LLM
from danswer.tools.force import ForceUseTool
from danswer.tools.message import build_tool_message
@@ -43,9 +43,6 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED

# Necessary for cloud integration tests
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"

# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
# information. This provides an extra layer of security on top of Postgres access controls
# and is available in Danswer EE
@@ -85,6 +82,7 @@ OAUTH_CLIENT_SECRET = (
)

USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")

# for basic auth
REQUIRE_EMAIL_VERIFICATION = (
    os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
@@ -118,6 +116,8 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
# the number of times to try and connect to vespa on startup before giving up
VESPA_NUM_ATTEMPTS_ON_STARTUP = int(os.environ.get("NUM_RETRIES_ON_STARTUP") or 10)

VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")


@@ -2,6 +2,8 @@ import json
import os


IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url")

# if specified, will pass through request headers to the call to API calls made by custom tools
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(

@@ -15,6 +15,7 @@ from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
from danswer.connectors.confluence.utils import validate_attachment_filetype
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
@@ -276,9 +277,11 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        ):
            # If the page has restrictions, add them to the perm_sync_data
            # These will be used by doc_sync.py to sync permissions
            perm_sync_data = {
                "restrictions": page.get("restrictions", {}),
                "space_key": page.get("space", {}).get("key"),
            page_restrictions = page.get("restrictions")
            page_space_key = page.get("space", {}).get("key")
            page_perm_sync_data = {
                "restrictions": page_restrictions or {},
                "space_key": page_space_key,
            }

            doc_metadata_list.append(
@@ -288,7 +291,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
                        page["_links"]["webui"],
                        self.is_cloud,
                    ),
                    perm_sync_data=perm_sync_data,
                    perm_sync_data=page_perm_sync_data,
                )
            )
            attachment_cql = f"type=attachment and container='{page['id']}'"
@@ -298,6 +301,21 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
                expand=restrictions_expand,
                limit=_SLIM_DOC_BATCH_SIZE,
            ):
                if not validate_attachment_filetype(attachment):
                    continue
                attachment_restrictions = attachment.get("restrictions")
                if not attachment_restrictions:
                    attachment_restrictions = page_restrictions

                attachment_space_key = attachment.get("space", {}).get("key")
                if not attachment_space_key:
                    attachment_space_key = page_space_key

                attachment_perm_sync_data = {
                    "restrictions": attachment_restrictions or {},
                    "space_key": attachment_space_key,
                }

                doc_metadata_list.append(
                    SlimDocument(
                        id=build_confluence_document_id(
@@ -305,7 +323,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
                            attachment["_links"]["webui"],
                            self.is_cloud,
                        ),
                        perm_sync_data=perm_sync_data,
                        perm_sync_data=attachment_perm_sync_data,
                    )
                )
                if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:

@@ -177,19 +177,23 @@ def extract_text_from_confluence_html(
    return format_document_soup(soup)


def attachment_to_content(
    confluence_client: OnyxConfluence,
    attachment: dict[str, Any],
) -> str | None:
    """If it returns None, assume that we should skip this attachment."""
    if attachment["metadata"]["mediaType"] in [
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
    return attachment["metadata"]["mediaType"] not in [
        "image/jpeg",
        "image/png",
        "image/gif",
        "image/svg+xml",
        "video/mp4",
        "video/quicktime",
    ]:
    ]
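
# Quick illustration (made-up metadata dicts): True means "worth extracting",
# i.e. any media type outside the image/video list above.
assert validate_attachment_filetype({"metadata": {"mediaType": "application/pdf"}}) is True
assert validate_attachment_filetype({"metadata": {"mediaType": "image/png"}}) is False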


def attachment_to_content(
    confluence_client: OnyxConfluence,
    attachment: dict[str, Any],
) -> str | None:
    """If it returns None, assume that we should skip this attachment."""
    if not validate_attachment_filetype(attachment):
        return None

    download_link = confluence_client.url + attachment["_links"]["download"]
@@ -245,7 +249,7 @@ def build_confluence_document_id(
    return f"{base_url}{content_url}"


def extract_referenced_attachment_names(page_text: str) -> list[str]:
def _extract_referenced_attachment_names(page_text: str) -> list[str]:
    """Parse a Confluence html page to generate a list of current
    attachments in use


@@ -5,7 +5,11 @@ from typing import cast

from sqlalchemy.orm import Session

from danswer.chat.models import PromptConfig
from danswer.chat.models import SectionRelevancePiece
from danswer.chat.prune_and_merge import _merge_sections
from danswer.chat.prune_and_merge import ChunkRange
from danswer.chat.prune_and_merge import merge_chunk_intervals
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import QueryFlow
@@ -27,10 +31,6 @@ from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger

@@ -4,6 +4,8 @@ schema DANSWER_CHUNK_NAME {
        # Not to be confused with the UUID generated for this chunk which is called documentid by default
        field document_id type string {
            indexing: summary | attribute
            attribute: fast-search
            rank: filter
        }
        field chunk_id type int {
            indexing: summary | attribute

@@ -6,6 +6,7 @@ import zipfile
from collections.abc import Callable
from collections.abc import Iterator
from email.parser import Parser as EmailParser
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import Dict
@@ -15,13 +16,17 @@ import chardet
import docx  # type: ignore
import openpyxl  # type: ignore
import pptx  # type: ignore
from docx import Document
from fastapi import UploadFile
from pypdf import PdfReader
from pypdf.errors import PdfStreamError

from danswer.configs.constants import DANSWER_METADATA_FILENAME
from danswer.configs.constants import FileOrigin
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.file_store.file_store import FileStore
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -375,3 +380,35 @@ def extract_file_text(
        ) from e
    logger.warning(f"Failed to process file {file_name or 'Unknown'}: {str(e)}")
    return ""


def convert_docx_to_txt(
    file: UploadFile, file_store: FileStore, file_path: str
) -> None:
    file.file.seek(0)
    docx_content = file.file.read()
    doc = Document(BytesIO(docx_content))

    # Extract text from the document
    full_text = []
    for para in doc.paragraphs:
        full_text.append(para.text)

    # Join the extracted text
    text_content = "\n".join(full_text)

    txt_file_path = docx_to_txt_filename(file_path)
    file_store.save_file(
        file_name=txt_file_path,
        content=BytesIO(text_content.encode("utf-8")),
        display_name=file.filename,
        file_origin=FileOrigin.CONNECTOR,
        file_type="text/plain",
    )


def docx_to_txt_filename(file_path: str) -> str:
    """
    Convert a .docx file path to its corresponding .txt file path.
    """
    return file_path.rsplit(".", 1)[0] + ".txt"
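
# Path mapping example: the converted text mirror lives next to the original,
# so fetch_chat_file (later in this diff) can prefer it when serving .docx files.
assert docx_to_txt_filename("uploads/abc-123/report.docx") == "uploads/abc-123/report.txt"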

@@ -1,6 +1,6 @@
import base64
from collections.abc import Callable
from io import BytesIO
from typing import Any
from typing import cast
from uuid import uuid4

@@ -13,8 +13,8 @@ from danswer.db.models import ChatMessage
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import FileDescriptor
from danswer.file_store.models import InMemoryChatFile
from danswer.utils.b64 import get_image_type
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


def load_chat_file(
@@ -75,11 +75,58 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
    return unique_id


def save_files_from_urls(urls: list[str]) -> list[str]:
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
    with get_session_with_tenant(tenant_id) as db_session:
        unique_id = str(uuid4())
        file_store = get_default_file_store(db_session)
        file_store.save_file(
            file_name=unique_id,
            content=BytesIO(base64.b64decode(base64_string)),
            display_name="GeneratedImage",
            file_origin=FileOrigin.CHAT_IMAGE_GEN,
            file_type=get_image_type(base64_string),
        )
        return unique_id

    funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
        (save_file_from_url, (url, tenant_id)) for url in urls

def save_file(
    tenant_id: str,
    url: str | None = None,
    base64_data: str | None = None,
) -> str:
    """Save a file from either a URL or base64 encoded string.

    Args:
        tenant_id: The tenant ID to save the file under
        url: URL to download file from
        base64_data: Base64 encoded file data

    Returns:
        The unique ID of the saved file

    Raises:
        ValueError: If neither url nor base64_data is provided, or if both are provided
    """
    if url is not None and base64_data is not None:
        raise ValueError("Cannot specify both url and base64_data")

    if url is not None:
        return save_file_from_url(url, tenant_id)
    elif base64_data is not None:
        return save_file_from_base64(base64_data, tenant_id)
    else:
        raise ValueError("Must specify either url or base64_data")


def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]:
    # NOTE: be explicit about typing so that if we change things, we get notified
    funcs: list[
        tuple[
            Callable[[str, str | None, str | None], str],
            tuple[str, str | None, str | None],
        ]
    ] = [(save_file, (tenant_id, url, None)) for url in urls] + [
        (save_file, (tenant_id, None, base64_file)) for base64_file in base64_files
    ]
    # Must pass in tenant_id here, since this is called by multithreading

    return run_functions_tuples_in_parallel(funcs)
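
# Usage sketch (placeholder inputs): one ID per input, URLs first, then base64
# payloads; the saves run in parallel via run_functions_tuples_in_parallel.
_file_ids = save_files(
    urls=["https://example.com/generated.png"],  # placeholder URL
    base64_files=["aGVsbG8="],                   # base64 for "hello"
    tenant_id="tenant-123",                      # placeholder tenant
)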

@@ -1,163 +0,0 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import TYPE_CHECKING

from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator

from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult

if TYPE_CHECKING:
    from danswer.db.models import ChatMessage
    from danswer.db.models import Prompt


StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]


class PreviousMessage(BaseModel):
    """Simplified version of `ChatMessage`"""

    message: str
    token_count: int
    message_type: MessageType
    files: list[InMemoryChatFile]
    tool_call: ToolCallFinalResult | None

    @classmethod
    def from_chat_message(
        cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
    ) -> "PreviousMessage":
        message_file_ids = (
            [file["id"] for file in chat_message.files] if chat_message.files else []
        )
        return cls(
            message=chat_message.message,
            token_count=chat_message.token_count,
            message_type=chat_message.message_type,
            files=[
                file
                for file in available_files
                if str(file.file_id) in message_file_ids
            ],
            tool_call=ToolCallFinalResult(
                tool_name=chat_message.tool_call.tool_name,
                tool_args=chat_message.tool_call.tool_arguments,
                tool_result=chat_message.tool_call.tool_result,
            )
            if chat_message.tool_call
            else None,
        )

    def to_langchain_msg(self) -> BaseMessage:
        content = build_content_with_imgs(self.message, self.files)
        if self.message_type == MessageType.USER:
            return HumanMessage(content=content)
        elif self.message_type == MessageType.ASSISTANT:
            return AIMessage(content=content)
        else:
            return SystemMessage(content=content)


class DocumentPruningConfig(BaseModel):
    max_chunks: int | None = None
    max_window_percentage: float | None = None
    max_tokens: int | None = None
    # different pruning behavior is expected when the
    # user manually selects documents they want to chat with
    # e.g. we don't want to truncate each document to be no more
    # than one chunk long
    is_manually_selected_docs: bool = False
    # If user specifies to include additional context Chunks for each match, then different pruning
    # is used. As many Sections as possible are included, and the last Section is truncated
    # If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
    # Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
    use_sections: bool = True
    # If using tools, then we need to consider the tool length
    tool_num_tokens: int = 0
    # If using a tool message to represent the docs, then we have to JSON serialize
    # the document content, which adds to the token count.
    using_tool_message: bool = False


class ContextualPruningConfig(DocumentPruningConfig):
    num_chunk_multiple: int

    @classmethod
    def from_doc_pruning_config(
        cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
    ) -> "ContextualPruningConfig":
        return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())


class CitationConfig(BaseModel):
    all_docs_useful: bool = False


class QuotesConfig(BaseModel):
    pass


class AnswerStyleConfig(BaseModel):
    citation_config: CitationConfig | None = None
    quotes_config: QuotesConfig | None = None
    document_pruning_config: DocumentPruningConfig = Field(
        default_factory=DocumentPruningConfig
    )
    # forces the LLM to return a structured response, see
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    # right now, only used by the simple chat API
    structured_response_format: dict | None = None

    @model_validator(mode="after")
    def check_quotes_and_citation(self) -> "AnswerStyleConfig":
        if self.citation_config is None and self.quotes_config is None:
            raise ValueError(
                "One of `citation_config` or `quotes_config` must be provided"
            )

        if self.citation_config is not None and self.quotes_config is not None:
            raise ValueError(
                "Only one of `citation_config` or `quotes_config` must be provided"
            )

        return self


class PromptConfig(BaseModel):
    """Final representation of the Prompt configuration passed
    into the `Answer` object."""

    system_prompt: str
    task_prompt: str
    datetime_aware: bool
    include_citations: bool

    @classmethod
    def from_model(
        cls, model: "Prompt", prompt_override: PromptOverride | None = None
    ) -> "PromptConfig":
        override_system_prompt = (
            prompt_override.system_prompt if prompt_override else None
        )
        override_task_prompt = prompt_override.task_prompt if prompt_override else None

        return cls(
            system_prompt=override_system_prompt or model.system_prompt,
            task_prompt=override_task_prompt or model.task_prompt,
            datetime_aware=model.datetime_aware,
            include_citations=model.include_citations,
        )

    model_config = ConfigDict(frozen=True)
@@ -1,20 +0,0 @@
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT


def build_dummy_prompt(
    system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
    if retrieval_disabled:
        return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
            user_query="<USER_QUERY>",
            system_prompt=system_prompt,
            task_prompt=task_prompt,
        ).strip()

    return PARAMATERIZED_PROMPT.format(
        context_docs_str="<CONTEXT_DOCS>",
        user_query="<USER_QUERY>",
        system_prompt=system_prompt,
        task_prompt=task_prompt,
    ).strip()

backend/danswer/llm/models.py (new file, 59 lines)
@@ -0,0 +1,59 @@
from typing import TYPE_CHECKING

from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel

from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult

if TYPE_CHECKING:
    from danswer.db.models import ChatMessage


class PreviousMessage(BaseModel):
    """Simplified version of `ChatMessage`"""

    message: str
    token_count: int
    message_type: MessageType
    files: list[InMemoryChatFile]
    tool_call: ToolCallFinalResult | None

    @classmethod
    def from_chat_message(
        cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
    ) -> "PreviousMessage":
        message_file_ids = (
            [file["id"] for file in chat_message.files] if chat_message.files else []
        )
        return cls(
            message=chat_message.message,
            token_count=chat_message.token_count,
            message_type=chat_message.message_type,
            files=[
                file
                for file in available_files
                if str(file.file_id) in message_file_ids
            ],
            tool_call=ToolCallFinalResult(
                tool_name=chat_message.tool_call.tool_name,
                tool_args=chat_message.tool_call.tool_arguments,
                tool_result=chat_message.tool_call.tool_result,
            )
            if chat_message.tool_call
            else None,
        )

    def to_langchain_msg(self) -> BaseMessage:
        content = build_content_with_imgs(self.message, self.files)
        if self.message_type == MessageType.USER:
            return HumanMessage(content=content)
        elif self.message_type == MessageType.ASSISTANT:
            return AIMessage(content=content)
        else:
            return SystemMessage(content=content)
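
# Round-trip sketch (chat_message / available_files come from the DB layer and
# are elided here): USER -> HumanMessage, ASSISTANT -> AIMessage, anything else
# -> SystemMessage.
_prev = PreviousMessage.from_chat_message(chat_message, available_files=files)
_lc_msg = _prev.to_langchain_msg()
print(type(_lc_msg).__name__, _prev.token_count)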
@@ -5,8 +5,6 @@ from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import Union

import litellm  # type: ignore
import pandas as pd
@@ -36,17 +34,15 @@ from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from danswer.db.models import ChatMessage
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.utils.b64 import get_image_type
from danswer.utils.b64 import get_image_type_from_bytes
from danswer.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL

if TYPE_CHECKING:
    from danswer.llm.answering.models import PreviousMessage

logger = setup_logger()


@@ -104,39 +100,6 @@ def litellm_exception_to_error_msg(
    return error_msg


def translate_danswer_msg_to_langchain(
    msg: Union[ChatMessage, "PreviousMessage"],
) -> BaseMessage:
    files: list[InMemoryChatFile] = []

    # If the message is a `ChatMessage`, it doesn't have the downloaded files
    # attached. Just ignore them for now.
    if not isinstance(msg, ChatMessage):
        files = msg.files
    content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)

    if msg.message_type == MessageType.SYSTEM:
        raise ValueError("System messages are not currently part of history")
    if msg.message_type == MessageType.ASSISTANT:
        return AIMessage(content=content)
    if msg.message_type == MessageType.USER:
        return HumanMessage(content=content)

    raise ValueError(f"New message type {msg.message_type} not handled")


def translate_history_to_basemessages(
    history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
    history_basemessages = [
        translate_danswer_msg_to_langchain(msg)
        for msg in history
        if msg.token_count != 0
    ]
    history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
    return history_basemessages, history_token_counts


# Processes CSV files to show the first 5 rows and max_columns (default 40) columns
def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str:
    df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
@@ -190,6 +153,7 @@ def build_content_with_imgs(
    message: str,
    files: list[InMemoryChatFile] | None = None,
    img_urls: list[str] | None = None,
    b64_imgs: list[str] | None = None,
    message_type: MessageType = MessageType.USER,
) -> str | list[str | dict[str, Any]]:  # matching Langchain's BaseMessage content type
    files = files or []
@@ -202,6 +166,7 @@ def build_content_with_imgs(
    )

    img_urls = img_urls or []
    b64_imgs = b64_imgs or []

    message_main_content = _build_content(message, files)

@@ -220,11 +185,22 @@ def build_content_with_imgs(
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/jpeg;base64,{file.to_base64()}",
                    "url": (
                        f"data:{get_image_type_from_bytes(file.content)};"
                        f"base64,{file.to_base64()}"
                    ),
                },
            }
            for file in files
            if file.file_type == "image"
            for file in img_files
        ]
        + [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
                },
            }
            for b64_img in b64_imgs
        ]
        + [
            {

@@ -5,11 +5,11 @@ from typing import cast
from langchain_core.messages import BaseMessage

from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.configs.constants import DocumentSource
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Prompt
from danswer.llm.answering.models import PromptConfig
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
from danswer.prompts.chat_prompts import CITATION_REMINDER
from danswer.prompts.constants import CODE_BLOCK_PAT

@@ -3,14 +3,14 @@ from langchain.schema import HumanMessage
from langchain.schema import SystemMessage

from danswer.chat.chat_utils import combine_message_chain
from danswer.chat.prompt_builder.utils import translate_danswer_msg_to_langchain
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.llm.utils import translate_danswer_msg_to_langchain
from danswer.prompts.chat_prompts import AGGRESSIVE_SEARCH_TEMPLATE
from danswer.prompts.chat_prompts import NO_SEARCH
from danswer.prompts.chat_prompts import REQUIRE_SEARCH_HINT

@@ -4,10 +4,10 @@ from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.chat_configs import DISABLE_LLM_QUERY_REPHRASE
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.chat_prompts import HISTORY_QUERY_REPHRASE

@@ -86,6 +86,7 @@ from danswer.db.models import SearchSettings
from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.db.search_settings import get_secondary_search_settings
from danswer.file_processing.extract_file_text import convert_docx_to_txt
from danswer.file_store.file_store import get_default_file_store
from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.redis.redis_connector import RedisConnector
@@ -393,6 +394,12 @@ def upload_files(
                file_origin=FileOrigin.CONNECTOR,
                file_type=file.content_type or "text/plain",
            )

            if file.content_type and file.content_type.startswith(
                "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
            ):
                convert_docx_to_txt(file, file_store, file_path)

    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return FileUploadResponse(file_paths=deduped_file_paths)
@@ -1010,37 +1017,18 @@ def get_connector_by_id(


class BasicCCPairInfo(BaseModel):
    docs_indexed: int
    has_successful_run: bool
    source: DocumentSource


@router.get("/indexing-status")
@router.get("/connector-status")
def get_basic_connector_indexing_status(
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[BasicCCPairInfo]:
    cc_pairs = get_connector_credential_pairs(db_session)
    cc_pair_identifiers = [
        ConnectorCredentialPairIdentifier(
            connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
        )
        for cc_pair in cc_pairs
    ]
    document_count_info = get_document_counts_for_cc_pairs(
        db_session=db_session,
        cc_pair_identifiers=cc_pair_identifiers,
    )
    cc_pair_to_document_cnt = {
        (connector_id, credential_id): cnt
        for connector_id, credential_id, cnt in document_count_info
    }
    return [
        BasicCCPairInfo(
            docs_indexed=cc_pair_to_document_cnt.get(
                (cc_pair.connector_id, cc_pair.credential_id)
            )
            or 0,
            has_successful_run=cc_pair.last_successful_index_time is not None,
            source=cc_pair.connector.source,
        )
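
# The route is renamed from /indexing-status to /connector-status with the same
# response shape; a hypothetical client call (host, port, router prefix, and
# auth are assumptions, not confirmed by the diff):
import requests

for _info in requests.get("http://localhost:8080/manage/connector-status").json():
    print(_info["source"], _info["docs_indexed"], _info["has_successful_run"])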

@@ -13,6 +13,7 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.chat.prompt_builder.utils import build_dummy_prompt
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
@@ -33,7 +34,6 @@ from danswer.db.persona import update_persona_shared_users
from danswer.db.persona import update_persona_visibility
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import ImageGenerationToolStatus
from danswer.server.features.persona.models import PersonaCategoryCreate

@@ -194,11 +194,11 @@ def bulk_invite_users(
        )

    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    normalized_emails = []
    new_invited_emails = []
    try:
        for email in emails:
            email_info = validate_email(email)
            normalized_emails.append(email_info.normalized)  # type: ignore
            new_invited_emails.append(email_info.normalized)

    except (EmailUndeliverableError, EmailNotValidError) as e:
        raise HTTPException(
@@ -210,7 +210,7 @@ def bulk_invite_users(
    try:
        fetch_ee_implementation_or_noop(
            "danswer.server.tenants.provisioning", "add_users_to_tenant", None
        )(normalized_emails, tenant_id)
        )(new_invited_emails, tenant_id)

    except IntegrityError as e:
        if isinstance(e.orig, UniqueViolation):
@@ -224,7 +224,7 @@ def bulk_invite_users(

    initial_invited_users = get_invited_users()

    all_emails = list(set(normalized_emails) | set(initial_invited_users))
    all_emails = list(set(new_invited_emails) | set(initial_invited_users))
    number_of_invited_users = write_invited_users(all_emails)

    if not MULTI_TENANT:
@@ -236,7 +236,7 @@ def bulk_invite_users(
        )(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
    if ENABLE_EMAIL_INVITES:
        try:
            for email in all_emails:
            for email in new_invited_emails:
                send_user_email_invite(email, current_user)
        except Exception as e:
            logger.error(f"Error sending email invite to invited users: {e}")
@@ -250,7 +250,7 @@ def bulk_invite_users(
        write_invited_users(initial_invited_users)  # Reset to original state
        fetch_ee_implementation_or_noop(
            "danswer.server.tenants.user_mapping", "remove_users_from_tenant", None
        )(normalized_emails, tenant_id)
        )(new_invited_emails, tenant_id)
        raise e
||||
|
||||
|
||||
|
||||
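Note: the rename above makes bulk_invite_users track only the newly invited, normalized addresses. A small standalone sketch of the normalization step, assuming the email-validator package that provides the validate_email used here; the sample addresses are invented.

```python
from email_validator import EmailNotValidError, validate_email

emails = ["Alice@Example.COM", "bob@example.com"]  # invented sample input
new_invited_emails = []
for email in emails:
    try:
        # check_deliverability=False keeps this sketch offline (no DNS lookup)
        info = validate_email(email, check_deliverability=False)
        new_invited_emails.append(info.normalized)
    except EmailNotValidError as e:
        print(f"rejected {email}: {e}")

print(new_invited_emails)  # ['Alice@example.com', 'bob@example.com'], domain lowercased
```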
@@ -1,6 +1,7 @@
import asyncio
import io
import json
+import os
import uuid
from collections.abc import Callable
from collections.abc import Generator

@@ -23,6 +24,9 @@ from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import extract_headers
from danswer.chat.process_message import stream_chat_message
+from danswer.chat.prompt_builder.citations_prompt import (
+    compute_max_document_tokens_for_persona,
+)
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType

@@ -47,13 +51,11 @@ from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
+from danswer.file_processing.extract_file_text import docx_to_txt_filename
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
-from danswer.llm.answering.prompts.citations_prompt import (
-    compute_max_document_tokens_for_persona,
-)
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llms_for_persona
@@ -718,6 +720,18 @@ def fetch_chat_file(
    if not file_record:
        raise HTTPException(status_code=404, detail="File not found")

+    original_file_name = file_record.display_name
+    if file_record.file_type.startswith(
+        "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
+    ):
+        # Check if a converted text file exists for .docx files
+        txt_file_name = docx_to_txt_filename(original_file_name)
+        txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
+        txt_file_record = file_store.read_file_record(txt_file_id)
+        if txt_file_record:
+            file_record = txt_file_record
+            file_id = txt_file_id
+
    media_type = file_record.file_type
    file_io = file_store.read_file(file_id, mode="b")
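Note: fetch_chat_file now prefers a pre-converted .txt sibling for .docx uploads. The stand-in below illustrates the path arithmetic; the real docx_to_txt_filename lives in danswer.file_processing.extract_file_text, and the extension-swap behavior shown is an assumption.

```python
import os


def docx_to_txt_filename(docx_filename: str) -> str:
    # Assumed behavior: replace the .docx extension with .txt
    return os.path.splitext(docx_filename)[0] + ".txt"


file_id = "chat_files/abc123/report.docx"  # invented example id
txt_file_id = os.path.join(
    os.path.dirname(file_id), docx_to_txt_filename(os.path.basename(file_id))
)
print(txt_file_id)  # chat_files/abc123/report.txt
```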
@@ -1,5 +1,6 @@
from datetime import datetime
from typing import Any
+from typing import TYPE_CHECKING
from uuid import UUID

from pydantic import BaseModel

@@ -22,6 +23,9 @@ from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.tools.models import ToolCallFinalResult

+if TYPE_CHECKING:
+    pass
+

class SourceTag(Tag):
    source: DocumentSource
@@ -4,6 +4,7 @@ from sqlalchemy.orm import Session

from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import MANAGED_VESPA
+from danswer.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION

@@ -221,13 +222,13 @@ def setup_vespa(
    document_index: DocumentIndex,
    index_setting: IndexingSetting,
    secondary_index_setting: IndexingSetting | None,
+    num_attempts: int = VESPA_NUM_ATTEMPTS_ON_STARTUP,
) -> bool:
    # Vespa startup is a bit slow, so give it a few seconds
    WAIT_SECONDS = 5
-    VESPA_ATTEMPTS = 5
-    for x in range(VESPA_ATTEMPTS):
+    for x in range(num_attempts):
        try:
-            logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
+            logger.notice(f"Setting up Vespa (attempt {x+1}/{num_attempts})...")
            document_index.ensure_indices_exist(
                index_embedding_dim=index_setting.model_dim,
                secondary_index_embedding_dim=secondary_index_setting.model_dim

@@ -244,7 +245,7 @@ def setup_vespa(
        time.sleep(WAIT_SECONDS)

    logger.error(
-        f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
+        f"Vespa setup did not succeed. Attempt limit reached. ({num_attempts})"
    )
    return False
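Note: the setup_vespa change above replaces the hardcoded VESPA_ATTEMPTS constant with a num_attempts parameter defaulting to the new VESPA_NUM_ATTEMPTS_ON_STARTUP config. A generic sketch of the same retry shape, with an invented ensure_ready callable standing in for document_index.ensure_indices_exist:

```python
import time
from collections.abc import Callable


def setup_with_retries(
    ensure_ready: Callable[[], None],
    num_attempts: int = 5,
    wait_seconds: int = 5,
) -> bool:
    for attempt in range(num_attempts):
        try:
            print(f"Setting up (attempt {attempt + 1}/{num_attempts})...")
            ensure_ready()  # stand-in for document_index.ensure_indices_exist(...)
            return True
        except Exception:
            # not ready yet, wait and retry until the attempt budget runs out
            time.sleep(wait_seconds)
    print(f"Setup did not succeed. Attempt limit reached. ({num_attempts})")
    return False
```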
@@ -7,7 +7,7 @@ from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.tools.tool import Tool

if TYPE_CHECKING:
-    from danswer.llm.answering.prompts.build import AnswerPromptBuilder
+    from danswer.chat.prompt_builder.build import AnswerPromptBuilder
    from danswer.tools.tool_implementations.custom.custom_tool import (
        CustomToolCallSummary,
    )

@@ -25,9 +25,6 @@ class ToolCallSummary(BaseModel__v1):
    tool_call_request: AIMessage
    tool_call_result: ToolMessage

-    class Config:
-        arbitrary_types_allowed = True
-

def tool_call_tokens(
    tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer

@@ -3,13 +3,13 @@ from collections.abc import Generator
from typing import Any
from typing import TYPE_CHECKING

-from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
+from danswer.llm.models import PreviousMessage
from danswer.utils.special_types import JSON_ro


if TYPE_CHECKING:
-    from danswer.llm.answering.prompts.build import AnswerPromptBuilder
+    from danswer.chat.prompt_builder.build import AnswerPromptBuilder
    from danswer.tools.message import ToolCallSummary
    from danswer.tools.models import ToolResponse
@@ -5,6 +5,10 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session

+from danswer.chat.models import AnswerStyleConfig
+from danswer.chat.models import CitationConfig
+from danswer.chat.models import DocumentPruningConfig
+from danswer.chat.models import PromptConfig
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION

@@ -19,10 +23,6 @@ from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.file_store.models import InMemoryChatFile
-from danswer.llm.answering.models import AnswerStyleConfig
-from danswer.llm.answering.models import CitationConfig
-from danswer.llm.answering.models import DocumentPruningConfig
-from danswer.llm.answering.models import PromptConfig
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
@@ -15,14 +15,14 @@ from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from requests import JSONDecodeError

+from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_default_tenant
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
-from danswer.llm.answering.models import PreviousMessage
-from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
+from danswer.llm.models import PreviousMessage
from danswer.tools.base_tool import BaseTool
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER
@@ -4,14 +4,16 @@ from enum import Enum
from typing import Any
from typing import cast

import requests
from litellm import image_generation  # type: ignore
from pydantic import BaseModel

from danswer.chat.chat_utils import combine_message_chain
+from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
-from danswer.llm.answering.models import PreviousMessage
-from danswer.llm.answering.prompts.build import AnswerPromptBuilder
+from danswer.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
from danswer.llm.interfaces import LLM
+from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
@@ -56,9 +58,18 @@ Follow Up Input:
""".strip()


+class ImageFormat(str, Enum):
+    URL = "url"
+    BASE64 = "b64_json"
+
+
+_DEFAULT_OUTPUT_FORMAT = ImageFormat(IMAGE_GENERATION_OUTPUT_FORMAT)
+
+
class ImageGenerationResponse(BaseModel):
    revised_prompt: str
-    url: str
+    url: str | None
+    image_data: str | None


class ImageShape(str, Enum):
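Note: _DEFAULT_OUTPUT_FORMAT feeds the configured IMAGE_GENERATION_OUTPUT_FORMAT string straight into the enum, so a bad value fails at import time rather than mid-request. A standalone copy of the enum demonstrating that behavior:

```python
from enum import Enum


class ImageFormat(str, Enum):
    URL = "url"
    BASE64 = "b64_json"


print(ImageFormat("url"))       # ImageFormat.URL
print(ImageFormat("b64_json"))  # ImageFormat.BASE64
try:
    ImageFormat("png")  # not a member value, rejected immediately
except ValueError as e:
    print(e)  # 'png' is not a valid ImageFormat
```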
@@ -80,6 +91,7 @@ class ImageGenerationTool(Tool):
        model: str = "dall-e-3",
        num_imgs: int = 2,
        additional_headers: dict[str, str] | None = None,
+        output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT,
    ) -> None:
        self.api_key = api_key
        self.api_base = api_base

@@ -89,6 +101,7 @@ class ImageGenerationTool(Tool):
        self.num_imgs = num_imgs

        self.additional_headers = additional_headers
+        self.output_format = output_format

    @property
    def name(self) -> str:

@@ -168,7 +181,7 @@ class ImageGenerationTool(Tool):
        )

        return build_content_with_imgs(
-            json.dumps(
+            message=json.dumps(
                [
                    {
                        "revised_prompt": image_generation.revised_prompt,

@@ -177,13 +190,10 @@ class ImageGenerationTool(Tool):
                    for image_generation in image_generations
                ]
            ),
-            # NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow
-            # Tool messages to contain images
-            # img_urls=[image_generation.url for image_generation in image_generations],
        )

    def _generate_image(
-        self, prompt: str, shape: ImageShape
+        self, prompt: str, shape: ImageShape, format: ImageFormat
    ) -> ImageGenerationResponse:
        if shape == ImageShape.LANDSCAPE:
            size = "1792x1024"
@@ -197,20 +207,32 @@ class ImageGenerationTool(Tool):
                prompt=prompt,
                model=self.model,
                api_key=self.api_key,
                # need to pass in None rather than empty str
                api_base=self.api_base or None,
                api_version=self.api_version or None,
                size=size,
                n=1,
+                response_format=format,
                extra_headers=build_llm_extra_headers(self.additional_headers),
            )

+            if format == ImageFormat.URL:
+                url = response.data[0]["url"]
+                image_data = None
+            else:
+                url = None
+                image_data = response.data[0]["b64_json"]
+
            return ImageGenerationResponse(
                revised_prompt=response.data[0]["revised_prompt"],
-                url=response.data[0]["url"],
+                url=url,
+                image_data=image_data,
            )

        except requests.RequestException as e:
            logger.error(f"Error fetching or converting image: {e}")
            raise ValueError("Failed to fetch or convert the generated image")
        except Exception as e:
-            logger.debug(f"Error occured during image generation: {e}")
+            logger.debug(f"Error occurred during image generation: {e}")

            error_message = str(e)
            if "OpenAIException" in str(type(e)):
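Note: after this change exactly one of url and image_data is populated, depending on the requested response format. A hypothetical consumer of the new dual-field response; the model class is copied locally for illustration:

```python
from pydantic import BaseModel


class ImageGenerationResponse(BaseModel):
    revised_prompt: str
    url: str | None
    image_data: str | None


def describe(resp: ImageGenerationResponse) -> str:
    # URL mode: the provider hosts the image; base64 mode: the bytes are inline.
    if resp.url is not None:
        return f"hosted image at {resp.url}"
    if resp.image_data is not None:
        return f"inline base64 payload ({len(resp.image_data)} chars)"
    return "no image returned"


resp = ImageGenerationResponse(
    revised_prompt="a cat", url="https://example.com/cat.png", image_data=None
)
print(describe(resp))  # hosted image at https://example.com/cat.png
```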
@@ -235,9 +257,8 @@ class ImageGenerationTool(Tool):
    def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
        prompt = cast(str, kwargs["prompt"])
        shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
+        format = self.output_format

-        # dalle3 only supports 1 image at a time, which is why we have to
-        # parallelize this via threading
        results = cast(
            list[ImageGenerationResponse],
            run_functions_tuples_in_parallel(

@@ -247,6 +268,7 @@ class ImageGenerationTool(Tool):
                (
                    prompt,
                    shape,
+                    format,
                ),
            )
            for _ in range(self.num_imgs)

@@ -288,11 +310,17 @@ class ImageGenerationTool(Tool):
        if img_generation_response is None:
            raise ValueError("No image generation response found")

-        img_urls = [img.url for img in img_generation_response]
+        img_urls = [img.url for img in img_generation_response if img.url is not None]
+        b64_imgs = [
+            img.image_data
+            for img in img_generation_response
+            if img.image_data is not None
+        ]
        prompt_builder.update_user_prompt(
            build_image_generation_user_prompt(
                query=prompt_builder.get_user_message_content(),
                img_urls=img_urls,
+                b64_imgs=b64_imgs,
            )
        )
@@ -11,11 +11,14 @@ Can you please summarize them in a sentence or two? Do NOT include image urls or


def build_image_generation_user_prompt(
-    query: str, img_urls: list[str] | None = None
+    query: str,
+    img_urls: list[str] | None = None,
+    b64_imgs: list[str] | None = None,
) -> HumanMessage:
    return HumanMessage(
        content=build_content_with_imgs(
            message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
+            b64_imgs=b64_imgs,
            img_urls=img_urls,
        )
    )
@@ -7,15 +7,15 @@ from typing import cast
import httpx

from danswer.chat.chat_utils import combine_message_chain
+from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import LlmDoc
+from danswer.chat.models import PromptConfig
+from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.context.search.models import SearchDoc
-from danswer.llm.answering.models import AnswerStyleConfig
-from danswer.llm.answering.models import PreviousMessage
-from danswer.llm.answering.models import PromptConfig
-from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
+from danswer.llm.models import PreviousMessage
from danswer.llm.utils import message_to_string
from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE
from danswer.prompts.constants import GENERAL_SEP_PAT
@@ -7,10 +7,19 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.chat.chat_utils import llm_doc_from_inference_section
+from danswer.chat.llm_response_handler import LLMCall
+from danswer.chat.models import AnswerStyleConfig
+from danswer.chat.models import ContextualPruningConfig
from danswer.chat.models import DanswerContext
from danswer.chat.models import DanswerContexts
+from danswer.chat.models import DocumentPruningConfig
from danswer.chat.models import LlmDoc
+from danswer.chat.models import PromptConfig
from danswer.chat.models import SectionRelevancePiece
+from danswer.chat.prompt_builder.build import AnswerPromptBuilder
+from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
+from danswer.chat.prune_and_merge import prune_and_merge_sections
+from danswer.chat.prune_and_merge import prune_sections
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS

@@ -25,17 +34,8 @@ from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.models import Persona
from danswer.db.models import User
-from danswer.llm.answering.llm_response_handler import LLMCall
-from danswer.llm.answering.models import AnswerStyleConfig
-from danswer.llm.answering.models import ContextualPruningConfig
-from danswer.llm.answering.models import DocumentPruningConfig
-from danswer.llm.answering.models import PreviousMessage
-from danswer.llm.answering.models import PromptConfig
-from danswer.llm.answering.prompts.build import AnswerPromptBuilder
-from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
-from danswer.llm.answering.prune_and_merge import prune_and_merge_sections
-from danswer.llm.answering.prune_and_merge import prune_sections
from danswer.llm.interfaces import LLM
+from danswer.llm.models import PreviousMessage
from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.message import ToolCallSummary
@@ -2,15 +2,15 @@ from typing import cast

from langchain_core.messages import HumanMessage

+from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import LlmDoc
-from danswer.llm.answering.models import AnswerStyleConfig
-from danswer.llm.answering.models import PromptConfig
-from danswer.llm.answering.prompts.build import AnswerPromptBuilder
-from danswer.llm.answering.prompts.citations_prompt import (
+from danswer.chat.models import PromptConfig
+from danswer.chat.prompt_builder.build import AnswerPromptBuilder
+from danswer.chat.prompt_builder.citations_prompt import (
    build_citations_system_message,
)
-from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
-from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
+from danswer.chat.prompt_builder.citations_prompt import build_citations_user_message
+from danswer.chat.prompt_builder.quotes_prompt import build_quotes_user_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse
@@ -2,8 +2,8 @@ from collections.abc import Callable
from collections.abc import Generator
from typing import Any

-from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
+from danswer.llm.models import PreviousMessage
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
@@ -3,8 +3,8 @@ from typing import Any

from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
-from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
+from danswer.llm.models import PreviousMessage
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.tools.tool import Tool
backend/danswer/utils/b64.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+import base64
+
+
+def get_image_type_from_bytes(raw_b64_bytes: bytes) -> str:
+    magic_number = raw_b64_bytes[:4]
+
+    if magic_number.startswith(b"\x89PNG"):
+        mime_type = "image/png"
+    elif magic_number.startswith(b"\xFF\xD8"):
+        mime_type = "image/jpeg"
+    elif magic_number.startswith(b"GIF8"):
+        mime_type = "image/gif"
+    elif magic_number.startswith(b"RIFF") and raw_b64_bytes[8:12] == b"WEBP":
+        mime_type = "image/webp"
+    else:
+        raise ValueError(
+            "Unsupported image format - only PNG, JPEG, " "GIF, and WEBP are supported."
+        )
+
+    return mime_type
+
+
+def get_image_type(raw_b64_string: str) -> str:
+    binary_data = base64.b64decode(raw_b64_string)
+    return get_image_type_from_bytes(binary_data)
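Note: a quick check of the magic-number dispatch in the new module, driving get_image_type_from_bytes directly with hand-built headers (assumes backend/ is on the import path):

```python
from danswer.utils.b64 import get_image_type_from_bytes

print(get_image_type_from_bytes(b"\x89PNG\r\n\x1a\n"))         # image/png
print(get_image_type_from_bytes(b"\xff\xd8\xff\xe0"))          # image/jpeg
print(get_image_type_from_bytes(b"GIF89a"))                    # image/gif
# WEBP needs both the RIFF magic and "WEBP" at byte offset 8..11
print(get_image_type_from_bytes(b"RIFF\x00\x00\x00\x00WEBP"))  # image/webp
```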
@@ -11,6 +11,14 @@ SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/ee/danswer/configs/saml

#####
# Auto Permission Sync
#####
+# In seconds, default is 5 minutes
+CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
+    os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
+)
+# In seconds, default is 5 minutes
+CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
+    os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
+)
NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)
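Note: the `or` fallback in these settings treats an unset and an empty environment variable the same way, unlike a plain os.environ.get default. A minimal demonstration:

```python
import os

os.environ["CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY"] = ""
# Empty string is falsy, so `or` still yields the 5-minute default.
print(int(os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60))  # 300

os.environ["CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY"] = "60"
print(int(os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60))  # 60
```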
Some files were not shown because too many files have changed in this diff.