mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 15:55:45 +00:00
Compare commits
8 Commits
initial_im
...
jwt_featur
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c7f787436d | ||
|
|
75a70a8169 | ||
|
|
c876bacaaf | ||
|
|
f103aa8ad4 | ||
|
|
8b875cde91 | ||
|
|
38a8a3afe3 | ||
|
|
ab47a05606 | ||
|
|
df9627f511 |
@@ -24,8 +24,6 @@ env:
|
||||
GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR: ${{ secrets.GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR }}
|
||||
GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR: ${{ secrets.GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR }}
|
||||
# Slab
|
||||
SLAB_BOT_TOKEN: ${{ secrets.SLAB_BOT_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
|
||||
@@ -73,7 +73,6 @@ RUN apt-get update && \
|
||||
rm -rf /var/lib/apt/lists/* && \
|
||||
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
|
||||
|
||||
|
||||
# Pre-downloading models for setups with limited egress
|
||||
RUN python -c "from tokenizers import Tokenizer; \
|
||||
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Literal
|
||||
from typing import Any
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
@@ -8,7 +8,6 @@ from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
@@ -36,18 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Combine Search and Chat
|
||||
|
||||
Revision ID: 9f696734098f
|
||||
Revises: a8c2065484e6
|
||||
Create Date: 2024-11-27 15:32:19.694972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9f696734098f"
|
||||
down_revision = "a8c2065484e6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("chat_session", "description", nullable=True)
|
||||
op.drop_column("chat_session", "one_shot")
|
||||
op.drop_column("slack_channel_config", "response_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
|
||||
op.alter_column("chat_session", "description", nullable=False)
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
|
||||
)
|
||||
op.add_column(
|
||||
"slack_channel_config",
|
||||
sa.Column(
|
||||
"response_type", sa.String(), nullable=False, server_default="citations"
|
||||
),
|
||||
)
|
||||
@@ -1,27 +0,0 @@
|
||||
"""add auto scroll to user model
|
||||
|
||||
Revision ID: a8c2065484e6
|
||||
Revises: abe7378b8217
|
||||
Create Date: 2024-11-22 17:34:09.690295
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a8c2065484e6"
|
||||
down_revision = "abe7378b8217"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "auto_scroll")
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
@@ -38,15 +37,8 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
|
||||
@@ -18,11 +18,6 @@ class ExternalAccess:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocExternalAccess:
|
||||
"""
|
||||
This is just a class to wrap the external access and the document ID
|
||||
together. It's used for syncing document permissions to Redis.
|
||||
"""
|
||||
|
||||
external_access: ExternalAccess
|
||||
# The document ID
|
||||
doc_id: str
|
||||
|
||||
@@ -1,42 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
from typing import Union
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAState
|
||||
from danswer.agent_search.primary_graph.states import RetrieverState
|
||||
from danswer.agent_search.primary_graph.states import VerifierState
|
||||
|
||||
|
||||
def sub_continue_to_verifier(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
|
||||
# Routes each de-douped retrieved doc to the verifier step - in parallel
|
||||
# Notice the 'Send()' API that takes care of the parallelization
|
||||
|
||||
return [
|
||||
Send(
|
||||
"sub_verifier",
|
||||
VerifierState(
|
||||
document=doc,
|
||||
#question=state["original_question"],
|
||||
question=state["sub_question_str"],
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
)
|
||||
for doc in state["sub_question_deduped_retrieval_docs"]
|
||||
]
|
||||
|
||||
|
||||
def sub_continue_to_retrieval(state: BaseQAState) -> Union[Hashable, list[Hashable]]:
|
||||
# Routes re-written queries to the (parallel) retrieval steps
|
||||
# Notice the 'Send()' API that takes care of the parallelization
|
||||
rewritten_queries = state["sub_question_search_queries"].rewritten_queries + [state["sub_question_str"]]
|
||||
return [
|
||||
Send(
|
||||
"sub_custom_retrieve",
|
||||
RetrieverState(
|
||||
rewritten_query=query,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
)
|
||||
for query in rewritten_queries
|
||||
]
|
||||
@@ -1,132 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from danswer.agent_search.core_qa_graph.edges import sub_continue_to_retrieval
|
||||
from danswer.agent_search.core_qa_graph.edges import sub_continue_to_verifier
|
||||
from danswer.agent_search.core_qa_graph.nodes.combine_retrieved_docs import (
|
||||
sub_combine_retrieved_docs,
|
||||
)
|
||||
from danswer.agent_search.core_qa_graph.nodes.custom_retrieve import (
|
||||
sub_custom_retrieve,
|
||||
)
|
||||
from danswer.agent_search.core_qa_graph.nodes.dummy import sub_dummy
|
||||
from danswer.agent_search.core_qa_graph.nodes.final_format import (
|
||||
sub_final_format,
|
||||
)
|
||||
from danswer.agent_search.core_qa_graph.nodes.generate import sub_generate
|
||||
from danswer.agent_search.core_qa_graph.nodes.qa_check import sub_qa_check
|
||||
from danswer.agent_search.core_qa_graph.nodes.rewrite import sub_rewrite
|
||||
from danswer.agent_search.core_qa_graph.nodes.verifier import sub_verifier
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAOutputState
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAState
|
||||
from danswer.agent_search.core_qa_graph.states import CoreQAInputState
|
||||
|
||||
|
||||
def build_core_qa_graph() -> StateGraph:
|
||||
sub_answers_initial = StateGraph(
|
||||
state_schema=BaseQAState,
|
||||
output=BaseQAOutputState,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
sub_answers_initial.add_node(node="sub_dummy", action=sub_dummy)
|
||||
sub_answers_initial.add_node(node="sub_rewrite", action=sub_rewrite)
|
||||
sub_answers_initial.add_node(
|
||||
node="sub_custom_retrieve",
|
||||
action=sub_custom_retrieve,
|
||||
)
|
||||
sub_answers_initial.add_node(
|
||||
node="sub_combine_retrieved_docs",
|
||||
action=sub_combine_retrieved_docs,
|
||||
)
|
||||
sub_answers_initial.add_node(
|
||||
node="sub_verifier",
|
||||
action=sub_verifier,
|
||||
)
|
||||
sub_answers_initial.add_node(
|
||||
node="sub_generate",
|
||||
action=sub_generate,
|
||||
)
|
||||
sub_answers_initial.add_node(
|
||||
node="sub_qa_check",
|
||||
action=sub_qa_check,
|
||||
)
|
||||
sub_answers_initial.add_node(
|
||||
node="sub_final_format",
|
||||
action=sub_final_format,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
sub_answers_initial.add_edge(START, "sub_dummy")
|
||||
sub_answers_initial.add_edge("sub_dummy", "sub_rewrite")
|
||||
|
||||
sub_answers_initial.add_conditional_edges(
|
||||
source="sub_rewrite",
|
||||
path=sub_continue_to_retrieval,
|
||||
)
|
||||
|
||||
sub_answers_initial.add_edge(
|
||||
start_key="sub_custom_retrieve",
|
||||
end_key="sub_combine_retrieved_docs",
|
||||
)
|
||||
|
||||
sub_answers_initial.add_conditional_edges(
|
||||
source="sub_combine_retrieved_docs",
|
||||
path=sub_continue_to_verifier,
|
||||
path_map=["sub_verifier"],
|
||||
)
|
||||
|
||||
sub_answers_initial.add_edge(
|
||||
start_key="sub_verifier",
|
||||
end_key="sub_generate",
|
||||
)
|
||||
|
||||
sub_answers_initial.add_edge(
|
||||
start_key="sub_generate",
|
||||
end_key="sub_qa_check",
|
||||
)
|
||||
|
||||
sub_answers_initial.add_edge(
|
||||
start_key="sub_qa_check",
|
||||
end_key="sub_final_format",
|
||||
)
|
||||
|
||||
sub_answers_initial.add_edge(
|
||||
start_key="sub_final_format",
|
||||
end_key=END,
|
||||
)
|
||||
# sub_answers_graph = sub_answers_initial.compile()
|
||||
return sub_answers_initial
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# q = "Whose music is kind of hard to easily enjoy?"
|
||||
# q = "What is voice leading?"
|
||||
# q = "What are the types of motions in music?"
|
||||
# q = "What are key elements of music theory?"
|
||||
# q = "How can I best understand music theory using voice leading?"
|
||||
q = "What makes good music?"
|
||||
# q = "types of motions in music"
|
||||
# q = "What is the relationship between music and physics?"
|
||||
# q = "Can you compare various grunge styles?"
|
||||
# q = "Why is quantum gravity so hard?"
|
||||
|
||||
inputs = CoreQAInputState(
|
||||
original_question=q,
|
||||
sub_question_str=q,
|
||||
)
|
||||
sub_answers_graph = build_core_qa_graph()
|
||||
compiled_sub_answers = sub_answers_graph.compile()
|
||||
output = compiled_sub_answers.invoke(inputs)
|
||||
print("\nOUTPUT:")
|
||||
print(output.keys())
|
||||
for key, value in output.items():
|
||||
if key in [
|
||||
"sub_question_answer",
|
||||
"sub_question_str",
|
||||
"sub_qas",
|
||||
"initial_sub_qas",
|
||||
"sub_question_answer",
|
||||
]:
|
||||
print(f"{key}: {value}")
|
||||
@@ -1,36 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def sub_combine_retrieved_docs(state: BaseQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Dedupe the retrieved docs.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
|
||||
|
||||
print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
|
||||
dedupe_docs: list[InferenceSection] = []
|
||||
for base_retrieval_doc in sub_question_base_retrieval_docs:
|
||||
if not any(
|
||||
base_retrieval_doc.center_chunk.chunk_id == doc.center_chunk.chunk_id
|
||||
for doc in dedupe_docs
|
||||
):
|
||||
dedupe_docs.append(base_retrieval_doc)
|
||||
|
||||
print(f"Number of deduped docs: {len(dedupe_docs)}")
|
||||
|
||||
|
||||
return {
|
||||
"sub_question_deduped_retrieval_docs": dedupe_docs,
|
||||
"log_messages": generate_log_message(
|
||||
message="sub - combine_retrieved_docs (dedupe)",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,66 +0,0 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.primary_graph.states import RetrieverState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import SearchRequest
|
||||
from danswer.context.search.pipeline import SearchPipeline
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.llm.factory import get_default_llms
|
||||
|
||||
|
||||
def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve documents
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): New key added to state, documents, that contains retrieved documents
|
||||
"""
|
||||
print("---RETRIEVE SUB---")
|
||||
|
||||
node_start_time = datetime.datetime.now()
|
||||
|
||||
rewritten_query = state["rewritten_query"]
|
||||
|
||||
# Retrieval
|
||||
# TODO: add the actual retrieval, probably from search_tool.run()
|
||||
documents: list[InferenceSection] = []
|
||||
llm, fast_llm = get_default_llms()
|
||||
with get_session_context_manager() as db_session:
|
||||
documents = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query=rewritten_query,
|
||||
),
|
||||
user=None,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
reranked_docs = documents.reranked_sections
|
||||
|
||||
# initial metric to measure fit TODO: implement metric properly
|
||||
|
||||
top_1_score = reranked_docs[0].center_chunk.score
|
||||
top_5_score = sum([doc.center_chunk.score for doc in reranked_docs[:5]]) / 5
|
||||
top_10_score = sum([doc.center_chunk.score for doc in reranked_docs[:10]]) / 10
|
||||
|
||||
fit_score = 1/3 * (top_1_score + top_5_score + top_10_score)
|
||||
|
||||
chunk_ids = {'query': rewritten_query,
|
||||
'chunk_ids': [doc.center_chunk.chunk_id for doc in reranked_docs]}
|
||||
|
||||
|
||||
return {
|
||||
"sub_question_base_retrieval_docs": reranked_docs,
|
||||
"sub_chunk_ids": [chunk_ids],
|
||||
"log_messages": generate_log_message(
|
||||
message=f"sub - custom_retrieve, fit_score: {fit_score}",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,24 +0,0 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def sub_dummy(state: BaseQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Dummy step
|
||||
"""
|
||||
|
||||
print("---Sub Dummy---")
|
||||
|
||||
node_start_time = datetime.datetime.now()
|
||||
|
||||
return {
|
||||
"graph_start_time": node_start_time,
|
||||
"log_messages": generate_log_message(
|
||||
message="sub - dummy",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=node_start_time,
|
||||
),
|
||||
}
|
||||
@@ -1,22 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAState
|
||||
|
||||
|
||||
def sub_final_format(state: BaseQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Create the final output for the QA subgraph
|
||||
"""
|
||||
|
||||
print("---BASE FINAL FORMAT---")
|
||||
|
||||
return {
|
||||
"sub_qas": [
|
||||
{
|
||||
"sub_question": state["sub_question_str"],
|
||||
"sub_answer": state["sub_question_answer"],
|
||||
"sub_answer_check": state["sub_question_answer_check"],
|
||||
}
|
||||
],
|
||||
"log_messages": state["log_messages"],
|
||||
}
|
||||
@@ -1,91 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_docs
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.llm.factory import get_default_llms
|
||||
|
||||
|
||||
def sub_generate(state: BaseQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Generate answer
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with re-phrased question
|
||||
"""
|
||||
print("---GENERATE---")
|
||||
|
||||
# Create sub-query results
|
||||
|
||||
verified_chunks = [chunk.center_chunk.chunk_id for chunk in state["sub_question_verified_retrieval_docs"]]
|
||||
result_dict = {}
|
||||
|
||||
chunk_id_dicts = state["sub_chunk_ids"]
|
||||
expanded_chunks = []
|
||||
original_chunks = []
|
||||
|
||||
for chunk_id_dict in chunk_id_dicts:
|
||||
sub_question = chunk_id_dict['query']
|
||||
verified_sq_chunks = [chunk_id for chunk_id in chunk_id_dict['chunk_ids'] if chunk_id in verified_chunks]
|
||||
|
||||
if sub_question != state["original_question"]:
|
||||
expanded_chunks += verified_sq_chunks
|
||||
else:
|
||||
result_dict['ORIGINAL'] = len(verified_sq_chunks)
|
||||
original_chunks += verified_sq_chunks
|
||||
result_dict[sub_question[:30]] = len(verified_sq_chunks)
|
||||
|
||||
expansion_chunks = set(expanded_chunks)
|
||||
num_expansion_chunks = sum([1 for chunk_id in expansion_chunks if chunk_id in verified_chunks])
|
||||
num_original_relevant_chunks = len(original_chunks)
|
||||
num_missed_relevant_chunks = sum([1 for chunk_id in original_chunks if chunk_id not in expansion_chunks])
|
||||
num_gained_relevant_chunks = sum([1 for chunk_id in expansion_chunks if chunk_id not in original_chunks])
|
||||
result_dict['expansion_chunks'] = num_expansion_chunks
|
||||
|
||||
|
||||
|
||||
print(result_dict)
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["sub_question_str"]
|
||||
docs = state["sub_question_verified_retrieval_docs"]
|
||||
|
||||
print(f"Number of verified retrieval docs: {len(docs)}")
|
||||
|
||||
# Only take the top 10 docs.
|
||||
# TODO: Make this dynamic or use config param?
|
||||
top_10_docs = docs[-10:]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(top_10_docs))
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
_, fast_llm = get_default_llms()
|
||||
response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
# structured_response_format=None,
|
||||
)
|
||||
)
|
||||
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
return {
|
||||
"sub_question_answer": answer_str,
|
||||
"log_messages": generate_log_message(
|
||||
message="base - generate",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.llm.factory import get_default_llms
|
||||
|
||||
|
||||
def sub_qa_check(state: BaseQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Check if the sub-question answer is satisfactory.
|
||||
|
||||
Args:
|
||||
state: The current SubQAState containing the sub-question and its answer
|
||||
|
||||
Returns:
|
||||
dict containing the check result and log message
|
||||
"""
|
||||
node_start_time = datetime.datetime.now()
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=BASE_CHECK_PROMPT.format(
|
||||
question=state["sub_question_str"],
|
||||
base_answer=state["sub_question_answer"],
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
_, fast_llm = get_default_llms()
|
||||
response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
# structured_response_format=None,
|
||||
)
|
||||
)
|
||||
|
||||
response_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
|
||||
return {
|
||||
"sub_question_answer_check": response_str,
|
||||
"base_answer_messages": generate_log_message(
|
||||
message="sub - qa_check",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAState
|
||||
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
|
||||
from danswer.agent_search.shared_graph_utils.prompts import (
|
||||
REWRITE_PROMPT_MULTI_ORIGINAL,
|
||||
)
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.llm.factory import get_default_llms
|
||||
|
||||
|
||||
def sub_rewrite(state: BaseQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Transform the initial question into more suitable search queries.
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with re-phrased question
|
||||
"""
|
||||
|
||||
print("---SUB TRANSFORM QUERY---")
|
||||
|
||||
node_start_time = datetime.datetime.now()
|
||||
|
||||
# messages = state["base_answer_messages"]
|
||||
question = state["sub_question_str"]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REWRITE_PROMPT_MULTI_ORIGINAL.format(question=question),
|
||||
)
|
||||
]
|
||||
|
||||
"""
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REWRITE_PROMPT_MULTI.format(question=question),
|
||||
)
|
||||
]
|
||||
"""
|
||||
|
||||
_, fast_llm = get_default_llms()
|
||||
llm_response_list = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
# structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
|
||||
# structured_response_format=RewrittenQueries.model_json_schema(),
|
||||
)
|
||||
)
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
|
||||
print(f"llm_response: {llm_response}")
|
||||
|
||||
rewritten_queries = llm_response.split("--")
|
||||
# rewritten_queries = [llm_response.split("\n")[0]]
|
||||
|
||||
print(f"rewritten_queries: {rewritten_queries}")
|
||||
|
||||
rewritten_queries = RewrittenQueries(rewritten_queries=rewritten_queries)
|
||||
|
||||
return {
|
||||
"sub_question_search_queries": rewritten_queries,
|
||||
"log_messages": generate_log_message(
|
||||
message="sub - rewrite",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
|
||||
from danswer.agent_search.primary_graph.states import VerifierState
|
||||
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
|
||||
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.llm.factory import get_default_llms
|
||||
|
||||
|
||||
def sub_verifier(state: VerifierState) -> dict[str, Any]:
|
||||
"""
|
||||
Check whether the document is relevant for the original user question
|
||||
|
||||
Args:
|
||||
state (VerifierState): The current state
|
||||
|
||||
Returns:
|
||||
dict: ict: The updated state with the final decision
|
||||
"""
|
||||
|
||||
# print("---VERIFY QUTPUT---")
|
||||
node_start_time = datetime.datetime.now()
|
||||
|
||||
question = state["question"]
|
||||
document_content = state["document"].combined_content
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=VERIFIER_PROMPT.format(
|
||||
question=question, document_content=document_content
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
llm, fast_llm = get_default_llms()
|
||||
response = list(
|
||||
llm.stream(
|
||||
prompt=msg,
|
||||
# structured_response_format=BinaryDecision.model_json_schema(),
|
||||
)
|
||||
)
|
||||
|
||||
response_string = merge_message_runs(response, chunk_separator="")[0].content
|
||||
# Convert string response to proper dictionary format
|
||||
decision_dict = {"decision": response_string.lower()}
|
||||
formatted_response = BinaryDecision.model_validate(decision_dict)
|
||||
|
||||
print(f"Verification end time: {datetime.datetime.now()}")
|
||||
|
||||
return {
|
||||
"sub_question_verified_retrieval_docs": [state["document"]]
|
||||
if formatted_response.decision == "yes"
|
||||
else [],
|
||||
"log_messages": generate_log_message(
|
||||
message=f"sub - verifier: {formatted_response.decision}",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,90 +0,0 @@
|
||||
import operator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.llm.interfaces import LLM
|
||||
|
||||
|
||||
class SubQuestionRetrieverState(TypedDict):
|
||||
# The state for the parallel Retrievers. They each need to see only one query
|
||||
sub_question_rewritten_query: str
|
||||
|
||||
|
||||
class SubQuestionVerifierState(TypedDict):
|
||||
# The state for the parallel verification step. Each node execution need to see only one question/doc pair
|
||||
sub_question_document: InferenceSection
|
||||
sub_question: str
|
||||
|
||||
|
||||
class CoreQAInputState(TypedDict):
|
||||
sub_question_str: str
|
||||
original_question: str
|
||||
|
||||
|
||||
class BaseQAState(TypedDict):
|
||||
# The 'core SubQuestion' state.
|
||||
original_question: str
|
||||
graph_start_time: datetime
|
||||
# start time for parallel initial sub-questionn thread
|
||||
sub_query_start_time: datetime
|
||||
sub_question_rewritten_queries: list[str]
|
||||
sub_question_str: str
|
||||
sub_question_search_queries: RewrittenQueries
|
||||
sub_question_nr: int
|
||||
sub_chunk_ids: Annotated[Sequence[dict], operator.add]
|
||||
sub_question_base_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_deduped_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_verified_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_reranked_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
|
||||
sub_question_answer: str
|
||||
sub_question_answer_check: str
|
||||
log_messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||
sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
# Answers sent back to core
|
||||
initial_sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
primary_llm: LLM
|
||||
fast_llm: LLM
|
||||
|
||||
|
||||
class BaseQAOutputState(TypedDict):
|
||||
# The 'SubQuestion' output state. Removes all the intermediate states
|
||||
sub_question_rewritten_queries: list[str]
|
||||
sub_question_str: str
|
||||
sub_question_search_queries: list[str]
|
||||
sub_question_nr: int
|
||||
# Answers sent back to core
|
||||
sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
# Answers sent back to core
|
||||
initial_sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
sub_question_base_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_deduped_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_verified_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_reranked_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
|
||||
sub_question_answer: str
|
||||
sub_question_answer_check: str
|
||||
log_messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||
@@ -1,46 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
from typing import Union
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
|
||||
from danswer.agent_search.primary_graph.states import RetrieverState
|
||||
from danswer.agent_search.primary_graph.states import VerifierState
|
||||
|
||||
|
||||
def sub_continue_to_verifier(state: ResearchQAState) -> Union[Hashable, list[Hashable]]:
|
||||
# Routes each de-douped retrieved doc to the verifier step - in parallel
|
||||
# Notice the 'Send()' API that takes care of the parallelization
|
||||
|
||||
return [
|
||||
Send(
|
||||
"sub_verifier",
|
||||
VerifierState(
|
||||
document=doc,
|
||||
question=state["sub_question"],
|
||||
primary_llm=state["primary_llm"],
|
||||
fast_llm=state["fast_llm"],
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
)
|
||||
for doc in state["sub_question_base_retrieval_docs"]
|
||||
]
|
||||
|
||||
|
||||
def sub_continue_to_retrieval(
|
||||
state: ResearchQAState,
|
||||
) -> Union[Hashable, list[Hashable]]:
|
||||
# Routes re-written queries to the (parallel) retrieval steps
|
||||
# Notice the 'Send()' API that takes care of the parallelization
|
||||
return [
|
||||
Send(
|
||||
"sub_custom_retrieve",
|
||||
RetrieverState(
|
||||
rewritten_query=query,
|
||||
primary_llm=state["primary_llm"],
|
||||
fast_llm=state["fast_llm"],
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
)
|
||||
for query in state["sub_question_rewritten_queries"]
|
||||
]
|
||||
@@ -1,93 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_retrieval
|
||||
from danswer.agent_search.deep_qa_graph.edges import sub_continue_to_verifier
|
||||
from danswer.agent_search.deep_qa_graph.nodes.combine_retrieved_docs import (
|
||||
sub_combine_retrieved_docs,
|
||||
)
|
||||
from danswer.agent_search.deep_qa_graph.nodes.custom_retrieve import sub_custom_retrieve
|
||||
from danswer.agent_search.deep_qa_graph.nodes.dummy import sub_dummy
|
||||
from danswer.agent_search.deep_qa_graph.nodes.final_format import sub_final_format
|
||||
from danswer.agent_search.deep_qa_graph.nodes.generate import sub_generate
|
||||
from danswer.agent_search.deep_qa_graph.nodes.qa_check import sub_qa_check
|
||||
from danswer.agent_search.deep_qa_graph.nodes.verifier import sub_verifier
|
||||
from danswer.agent_search.deep_qa_graph.states import ResearchQAOutputState
|
||||
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
|
||||
|
||||
|
||||
def build_deep_qa_graph() -> StateGraph:
|
||||
# Define the nodes we will cycle between
|
||||
sub_answers = StateGraph(state_schema=ResearchQAState, output=ResearchQAOutputState)
|
||||
|
||||
### Add Nodes ###
|
||||
|
||||
# Dummy node for initial processing
|
||||
sub_answers.add_node(node="sub_dummy", action=sub_dummy)
|
||||
|
||||
# The retrieval step
|
||||
sub_answers.add_node(node="sub_custom_retrieve", action=sub_custom_retrieve)
|
||||
|
||||
# The dedupe step
|
||||
sub_answers.add_node(
|
||||
node="sub_combine_retrieved_docs", action=sub_combine_retrieved_docs
|
||||
)
|
||||
|
||||
# Verifying retrieved information
|
||||
sub_answers.add_node(node="sub_verifier", action=sub_verifier)
|
||||
|
||||
# Generating the response
|
||||
sub_answers.add_node(node="sub_generate", action=sub_generate)
|
||||
|
||||
# Checking the quality of the answer
|
||||
sub_answers.add_node(node="sub_qa_check", action=sub_qa_check)
|
||||
|
||||
# Final formatting of the response
|
||||
sub_answers.add_node(node="sub_final_format", action=sub_final_format)
|
||||
|
||||
### Add Edges ###
|
||||
|
||||
# Generate multiple sub-questions
|
||||
sub_answers.add_edge(start_key=START, end_key="sub_rewrite")
|
||||
|
||||
# For each sub-question, perform a retrieval in parallel
|
||||
sub_answers.add_conditional_edges(
|
||||
source="sub_rewrite",
|
||||
path=sub_continue_to_retrieval,
|
||||
path_map=["sub_custom_retrieve"],
|
||||
)
|
||||
|
||||
# Combine the retrieved docs for each sub-question from the parallel retrievals
|
||||
sub_answers.add_edge(
|
||||
start_key="sub_custom_retrieve", end_key="sub_combine_retrieved_docs"
|
||||
)
|
||||
|
||||
# Go over all of the combined retrieved docs and verify them against the original question
|
||||
sub_answers.add_conditional_edges(
|
||||
source="sub_combine_retrieved_docs",
|
||||
path=sub_continue_to_verifier,
|
||||
path_map=["sub_verifier"],
|
||||
)
|
||||
|
||||
# Generate an answer for each verified retrieved doc
|
||||
sub_answers.add_edge(start_key="sub_verifier", end_key="sub_generate")
|
||||
|
||||
# Check the quality of the answer
|
||||
sub_answers.add_edge(start_key="sub_generate", end_key="sub_qa_check")
|
||||
|
||||
sub_answers.add_edge(start_key="sub_qa_check", end_key="sub_final_format")
|
||||
|
||||
sub_answers.add_edge(start_key="sub_final_format", end_key=END)
|
||||
|
||||
return sub_answers
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO: add the actual question
|
||||
inputs = {"sub_question": "Whose music is kind of hard to easily enjoy?"}
|
||||
sub_answers_graph = build_deep_qa_graph()
|
||||
compiled_sub_answers = sub_answers_graph.compile()
|
||||
output = compiled_sub_answers.invoke(inputs)
|
||||
print("\nOUTPUT:")
|
||||
print(output)
|
||||
@@ -1,31 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def sub_combine_retrieved_docs(state: ResearchQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Dedupe the retrieved docs.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
sub_question_base_retrieval_docs = state["sub_question_base_retrieval_docs"]
|
||||
|
||||
print(f"Number of docs from steps: {len(sub_question_base_retrieval_docs)}")
|
||||
dedupe_docs = []
|
||||
for base_retrieval_doc in sub_question_base_retrieval_docs:
|
||||
if base_retrieval_doc not in dedupe_docs:
|
||||
dedupe_docs.append(base_retrieval_doc)
|
||||
|
||||
print(f"Number of deduped docs: {len(dedupe_docs)}")
|
||||
|
||||
return {
|
||||
"sub_question_deduped_retrieval_docs": dedupe_docs,
|
||||
"log_messages": generate_log_message(
|
||||
message="sub - combine_retrieved_docs (dedupe)",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,33 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.primary_graph.states import RetrieverState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def sub_custom_retrieve(state: RetrieverState) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve documents
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): New key added to state, documents, that contains retrieved documents
|
||||
"""
|
||||
print("---RETRIEVE SUB---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
# Retrieval
|
||||
# TODO: add the actual retrieval, probably from search_tool.run()
|
||||
documents: list[InferenceSection] = []
|
||||
|
||||
return {
|
||||
"sub_question_base_retrieval_docs": documents,
|
||||
"log_messages": generate_log_message(
|
||||
message="sub - custom_retrieve",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,21 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def sub_dummy(state: BaseQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Dummy step
|
||||
"""
|
||||
|
||||
print("---Sub Dummy---")
|
||||
|
||||
return {
|
||||
"log_messages": generate_log_message(
|
||||
message="sub - dummy",
|
||||
node_start_time=datetime.now(),
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def sub_final_format(state: ResearchQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Create the final output for the QA subgraph
|
||||
"""
|
||||
|
||||
print("---SUB FINAL FORMAT---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
return {
|
||||
# TODO: Type this
|
||||
"sub_qas": [
|
||||
{
|
||||
"sub_question": state["sub_question"],
|
||||
"sub_answer": state["sub_question_answer"],
|
||||
"sub_question_nr": state["sub_question_nr"],
|
||||
"sub_answer_check": state["sub_question_answer_check"],
|
||||
}
|
||||
],
|
||||
"log_messages": generate_log_message(
|
||||
message="sub - final format",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,56 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
|
||||
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_docs
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def sub_generate(state: ResearchQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Generate answer
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with re-phrased question
|
||||
"""
|
||||
print("---SUB GENERATE---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["sub_question"]
|
||||
docs = state["sub_question_verified_retrieval_docs"]
|
||||
|
||||
print(f"Number of verified retrieval docs for sub-question: {len(docs)}")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
if len(docs) > 0:
|
||||
model = state["fast_llm"]
|
||||
response = list(
|
||||
model.stream(
|
||||
prompt=msg,
|
||||
)
|
||||
)
|
||||
response_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
else:
|
||||
response_str = ""
|
||||
|
||||
return {
|
||||
"sub_question_answer": response_str,
|
||||
"log_messages": generate_log_message(
|
||||
message="sub - generate",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.deep_qa_graph.prompts import SUB_CHECK_PROMPT
|
||||
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
|
||||
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def sub_qa_check(state: ResearchQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Check whether the final output satisfies the original user question
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with the final decision
|
||||
"""
|
||||
|
||||
print("---CHECK SUB QUTPUT---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
sub_answer = state["sub_question_answer"]
|
||||
sub_question = state["sub_question"]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=SUB_CHECK_PROMPT.format(
|
||||
sub_question=sub_question, sub_answer=sub_answer
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = state["fast_llm"]
|
||||
response = list(
|
||||
model.stream(
|
||||
prompt=msg,
|
||||
structured_response_format=BinaryDecision.model_json_schema(),
|
||||
)
|
||||
)
|
||||
|
||||
raw_response = json.loads(response[0].pretty_repr())
|
||||
formatted_response = BinaryDecision.model_validate(raw_response)
|
||||
|
||||
return {
|
||||
"sub_question_answer_check": formatted_response.decision,
|
||||
"log_messages": generate_log_message(
|
||||
message=f"sub - qa check: {formatted_response.decision}",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
|
||||
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
|
||||
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.llm.interfaces import LLM
|
||||
|
||||
|
||||
def sub_rewrite(state: ResearchQAState) -> dict[str, Any]:
|
||||
"""
|
||||
Transform the initial question into more suitable search queries.
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with re-phrased question
|
||||
"""
|
||||
|
||||
print("---SUB TRANSFORM QUERY---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["sub_question"]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REWRITE_PROMPT_MULTI.format(question=question),
|
||||
)
|
||||
]
|
||||
fast_llm: LLM = state["fast_llm"]
|
||||
llm_response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
structured_response_format=RewrittenQueries.model_json_schema(),
|
||||
)
|
||||
)
|
||||
|
||||
# Get the rewritten queries in a defined format
|
||||
rewritten_queries: RewrittenQueries = json.loads(llm_response[0].pretty_repr())
|
||||
|
||||
print(f"rewritten_queries: {rewritten_queries}")
|
||||
|
||||
rewritten_queries = RewrittenQueries(
|
||||
rewritten_queries=[
|
||||
"music hard to listen to",
|
||||
"Music that is not fun or pleasant",
|
||||
]
|
||||
)
|
||||
|
||||
print(f"hardcoded rewritten_queries: {rewritten_queries}")
|
||||
|
||||
return {
|
||||
"sub_question_rewritten_queries": rewritten_queries,
|
||||
"log_messages": generate_log_message(
|
||||
message="sub - rewrite",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.primary_graph.states import VerifierState
|
||||
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
|
||||
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def sub_verifier(state: VerifierState) -> dict[str, Any]:
|
||||
"""
|
||||
Check whether the document is relevant for the original user question
|
||||
|
||||
Args:
|
||||
state (VerifierState): The current state
|
||||
|
||||
Returns:
|
||||
dict: ict: The updated state with the final decision
|
||||
"""
|
||||
|
||||
print("---SUB VERIFY QUTPUT---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["question"]
|
||||
document_content = state["document"].combined_content
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=VERIFIER_PROMPT.format(
|
||||
question=question, document_content=document_content
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = state["fast_llm"]
|
||||
response = list(
|
||||
model.stream(
|
||||
prompt=msg,
|
||||
structured_response_format=BinaryDecision.model_json_schema(),
|
||||
)
|
||||
)
|
||||
|
||||
raw_response = json.loads(response[0].pretty_repr())
|
||||
formatted_response = BinaryDecision.model_validate(raw_response)
|
||||
|
||||
return {
|
||||
"deduped_retrieval_docs": [state["document"]]
|
||||
if formatted_response.decision == "yes"
|
||||
else [],
|
||||
"log_messages": generate_log_message(
|
||||
message=f"core - verifier: {formatted_response.decision}",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,13 +0,0 @@
|
||||
SUB_CHECK_PROMPT = """ \n
|
||||
Please check whether the suggested answer seems to address the original question.
|
||||
|
||||
Please only answer with 'yes' or 'no' \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the proposed answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
Please answer with yes or no:"""
|
||||
@@ -1,64 +0,0 @@
|
||||
import operator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.llm.interfaces import LLM
|
||||
|
||||
|
||||
class ResearchQAState(TypedDict):
|
||||
# The 'core SubQuestion' state.
|
||||
original_question: str
|
||||
graph_start_time: datetime
|
||||
sub_question_rewritten_queries: list[str]
|
||||
sub_question: str
|
||||
sub_question_nr: int
|
||||
sub_question_base_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_deduped_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_verified_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_reranked_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
|
||||
sub_question_answer: str
|
||||
sub_question_answer_check: str
|
||||
log_messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||
sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
primary_llm: LLM
|
||||
fast_llm: LLM
|
||||
|
||||
|
||||
class ResearchQAOutputState(TypedDict):
|
||||
# The 'SubQuestion' output state. Removes all the intermediate states
|
||||
sub_question_rewritten_queries: list[str]
|
||||
sub_question: str
|
||||
sub_question_nr: int
|
||||
# Answers sent back to core
|
||||
sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
sub_question_base_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_deduped_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_verified_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_reranked_retrieval_docs: Annotated[
|
||||
Sequence[InferenceSection], operator.add
|
||||
]
|
||||
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
|
||||
sub_question_answer: str
|
||||
sub_question_answer_check: str
|
||||
log_messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||
@@ -1,75 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
from typing import Union
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langgraph.types import Send
|
||||
|
||||
from danswer.agent_search.core_qa_graph.states import BaseQAState
|
||||
from danswer.agent_search.deep_qa_graph.states import ResearchQAState
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
|
||||
|
||||
|
||||
def continue_to_initial_sub_questions(
|
||||
state: QAState,
|
||||
) -> Union[Hashable, list[Hashable]]:
|
||||
# Routes re-written queries to the (parallel) retrieval steps
|
||||
# Notice the 'Send()' API that takes care of the parallelization
|
||||
return [
|
||||
Send(
|
||||
"sub_answers_graph_initial",
|
||||
BaseQAState(
|
||||
sub_question_str=initial_sub_question["sub_question_str"],
|
||||
sub_question_search_queries=initial_sub_question[
|
||||
"sub_question_search_queries"
|
||||
],
|
||||
sub_question_nr=initial_sub_question["sub_question_nr"],
|
||||
primary_llm=state["primary_llm"],
|
||||
fast_llm=state["fast_llm"],
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
)
|
||||
for initial_sub_question in state["initial_sub_questions"]
|
||||
]
|
||||
|
||||
|
||||
def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
|
||||
# Routes re-written queries to the (parallel) retrieval steps
|
||||
# Notice the 'Send()' API that takes care of the parallelization
|
||||
return [
|
||||
Send(
|
||||
"sub_answers_graph",
|
||||
ResearchQAState(
|
||||
sub_question=sub_question["sub_question_str"],
|
||||
sub_question_nr=sub_question["sub_question_nr"],
|
||||
graph_start_time=state["graph_start_time"],
|
||||
primary_llm=state["primary_llm"],
|
||||
fast_llm=state["fast_llm"],
|
||||
),
|
||||
)
|
||||
for sub_question in state["sub_questions"]
|
||||
]
|
||||
|
||||
|
||||
def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
|
||||
print("---GO TO DEEP ANSWER OR END---")
|
||||
|
||||
base_answer = state["base_answer"]
|
||||
|
||||
question = state["original_question"]
|
||||
|
||||
BASE_CHECK_MESSAGE = [
|
||||
HumanMessage(
|
||||
content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
|
||||
)
|
||||
]
|
||||
|
||||
model = state["fast_llm"]
|
||||
response = model.invoke(BASE_CHECK_MESSAGE)
|
||||
|
||||
print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")
|
||||
|
||||
if response.pretty_repr() == "no":
|
||||
return "decompose"
|
||||
else:
|
||||
return "end"
|
||||
@@ -1,171 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from danswer.agent_search.core_qa_graph.graph_builder import build_core_qa_graph
|
||||
from danswer.agent_search.deep_qa_graph.graph_builder import build_deep_qa_graph
|
||||
from danswer.agent_search.primary_graph.edges import continue_to_answer_sub_questions
|
||||
from danswer.agent_search.primary_graph.edges import continue_to_deep_answer
|
||||
from danswer.agent_search.primary_graph.edges import continue_to_initial_sub_questions
|
||||
from danswer.agent_search.primary_graph.nodes.base_wait import base_wait
|
||||
from danswer.agent_search.primary_graph.nodes.combine_retrieved_docs import (
|
||||
combine_retrieved_docs,
|
||||
)
|
||||
from danswer.agent_search.primary_graph.nodes.custom_retrieve import custom_retrieve
|
||||
from danswer.agent_search.primary_graph.nodes.decompose import decompose
|
||||
from danswer.agent_search.primary_graph.nodes.deep_answer_generation import (
|
||||
deep_answer_generation,
|
||||
)
|
||||
from danswer.agent_search.primary_graph.nodes.dummy_start import dummy_start
|
||||
from danswer.agent_search.primary_graph.nodes.entity_term_extraction import (
|
||||
entity_term_extraction,
|
||||
)
|
||||
from danswer.agent_search.primary_graph.nodes.final_stuff import final_stuff
|
||||
from danswer.agent_search.primary_graph.nodes.generate_initial import generate_initial
|
||||
from danswer.agent_search.primary_graph.nodes.main_decomp_base import main_decomp_base
|
||||
from danswer.agent_search.primary_graph.nodes.rewrite import rewrite
|
||||
from danswer.agent_search.primary_graph.nodes.sub_qa_level_aggregator import (
|
||||
sub_qa_level_aggregator,
|
||||
)
|
||||
from danswer.agent_search.primary_graph.nodes.sub_qa_manager import sub_qa_manager
|
||||
from danswer.agent_search.primary_graph.nodes.verifier import verifier
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
|
||||
|
||||
def build_core_graph() -> StateGraph:
|
||||
# Define the nodes we will cycle between
|
||||
core_answer_graph = StateGraph(state_schema=QAState)
|
||||
|
||||
### Add Nodes ###
|
||||
core_answer_graph.add_node(node="dummy_start",
|
||||
action=dummy_start)
|
||||
|
||||
# Re-writing the question
|
||||
core_answer_graph.add_node(node="rewrite",
|
||||
action=rewrite)
|
||||
|
||||
# The retrieval step
|
||||
core_answer_graph.add_node(node="custom_retrieve",
|
||||
action=custom_retrieve)
|
||||
|
||||
# Combine and dedupe retrieved docs.
|
||||
core_answer_graph.add_node(
|
||||
node="combine_retrieved_docs",
|
||||
action=combine_retrieved_docs
|
||||
)
|
||||
|
||||
# Extract entities, terms and relationships
|
||||
core_answer_graph.add_node(
|
||||
node="entity_term_extraction",
|
||||
action=entity_term_extraction
|
||||
)
|
||||
|
||||
# Verifying that a retrieved doc is relevant
|
||||
core_answer_graph.add_node(node="verifier",
|
||||
action=verifier)
|
||||
|
||||
# Initial question decomposition
|
||||
core_answer_graph.add_node(node="main_decomp_base",
|
||||
action=main_decomp_base)
|
||||
|
||||
# Build the base QA sub-graph and compile it
|
||||
compiled_core_qa_graph = build_core_qa_graph().compile()
|
||||
# Add the compiled base QA sub-graph as a node to the core graph
|
||||
core_answer_graph.add_node(
|
||||
node="sub_answers_graph_initial",
|
||||
action=compiled_core_qa_graph
|
||||
)
|
||||
|
||||
# Checking whether the initial answer is in the ballpark
|
||||
core_answer_graph.add_node(node="base_wait",
|
||||
action=base_wait)
|
||||
|
||||
# Decompose the question into sub-questions
|
||||
core_answer_graph.add_node(node="decompose",
|
||||
action=decompose)
|
||||
|
||||
# Manage the sub-questions
|
||||
core_answer_graph.add_node(node="sub_qa_manager",
|
||||
action=sub_qa_manager)
|
||||
|
||||
# Build the research QA sub-graph and compile it
|
||||
compiled_deep_qa_graph = build_deep_qa_graph().compile()
|
||||
# Add the compiled research QA sub-graph as a node to the core graph
|
||||
core_answer_graph.add_node(node="sub_answers_graph",
|
||||
action=compiled_deep_qa_graph)
|
||||
|
||||
# Aggregate the sub-questions
|
||||
core_answer_graph.add_node(
|
||||
node="sub_qa_level_aggregator",
|
||||
action=sub_qa_level_aggregator
|
||||
)
|
||||
|
||||
# aggregate sub questions and answers
|
||||
core_answer_graph.add_node(
|
||||
node="deep_answer_generation",
|
||||
action=deep_answer_generation
|
||||
)
|
||||
|
||||
# A final clean-up step
|
||||
core_answer_graph.add_node(node="final_stuff",
|
||||
action=final_stuff)
|
||||
|
||||
# Generating a response after we know the documents are relevant
|
||||
core_answer_graph.add_node(node="generate_initial",
|
||||
action=generate_initial)
|
||||
|
||||
### Add Edges ###
|
||||
|
||||
# start the initial sub-question decomposition
|
||||
core_answer_graph.add_edge(start_key=START,
|
||||
end_key="main_decomp_base")
|
||||
|
||||
core_answer_graph.add_conditional_edges(
|
||||
source="main_decomp_base",
|
||||
path=continue_to_initial_sub_questions,
|
||||
)
|
||||
|
||||
# use the retrieved information to generate the answer
|
||||
core_answer_graph.add_edge(
|
||||
start_key=["verifier", "sub_answers_graph_initial"],
|
||||
end_key="generate_initial"
|
||||
)
|
||||
core_answer_graph.add_edge(start_key="generate_initial",
|
||||
end_key="base_wait")
|
||||
|
||||
core_answer_graph.add_conditional_edges(
|
||||
source="base_wait",
|
||||
path=continue_to_deep_answer,
|
||||
path_map={"decompose": "entity_term_extraction", "end": "final_stuff"},
|
||||
)
|
||||
|
||||
core_answer_graph.add_edge(start_key="entity_term_extraction", end_key="decompose")
|
||||
|
||||
core_answer_graph.add_edge(start_key="decompose",
|
||||
end_key="sub_qa_manager")
|
||||
core_answer_graph.add_conditional_edges(
|
||||
source="sub_qa_manager",
|
||||
path=continue_to_answer_sub_questions,
|
||||
)
|
||||
|
||||
core_answer_graph.add_edge(
|
||||
start_key="sub_answers_graph",
|
||||
end_key="sub_qa_level_aggregator"
|
||||
)
|
||||
|
||||
core_answer_graph.add_edge(
|
||||
start_key="sub_qa_level_aggregator",
|
||||
end_key="deep_answer_generation"
|
||||
)
|
||||
|
||||
core_answer_graph.add_edge(
|
||||
start_key="deep_answer_generation",
|
||||
end_key="final_stuff"
|
||||
)
|
||||
|
||||
core_answer_graph.add_edge(start_key="final_stuff",
|
||||
end_key=END)
|
||||
|
||||
core_answer_graph.compile()
|
||||
|
||||
return core_answer_graph
|
||||
@@ -1,27 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def base_wait(state: QAState) -> dict[str, Any]:
|
||||
"""
|
||||
Ensures that all required steps are completed before proceeding to the next step
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: {} (no operation, just logging)
|
||||
"""
|
||||
|
||||
print("---Base Wait ---")
|
||||
node_start_time = datetime.now()
|
||||
return {
|
||||
"log_messages": generate_log_message(
|
||||
message="core - base_wait",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def combine_retrieved_docs(state: QAState) -> dict[str, Any]:
|
||||
"""
|
||||
Dedupe the retrieved docs.
|
||||
"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
base_retrieval_docs: Sequence[InferenceSection] = state["base_retrieval_docs"]
|
||||
|
||||
print(f"Number of docs from steps: {len(base_retrieval_docs)}")
|
||||
dedupe_docs: list[InferenceSection] = []
|
||||
for base_retrieval_doc in base_retrieval_docs:
|
||||
if not any(
|
||||
base_retrieval_doc.center_chunk.document_id == doc.center_chunk.document_id
|
||||
for doc in dedupe_docs
|
||||
):
|
||||
dedupe_docs.append(base_retrieval_doc)
|
||||
|
||||
print(f"Number of deduped docs: {len(dedupe_docs)}")
|
||||
|
||||
return {
|
||||
"deduped_retrieval_docs": dedupe_docs,
|
||||
"log_messages": generate_log_message(
|
||||
message="core - combine_retrieved_docs (dedupe)",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.primary_graph.states import RetrieverState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import SearchRequest
|
||||
from danswer.context.search.pipeline import SearchPipeline
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.llm.factory import get_default_llms
|
||||
|
||||
|
||||
def custom_retrieve(state: RetrieverState) -> dict[str, Any]:
|
||||
"""
|
||||
Retrieve documents
|
||||
|
||||
Args:
|
||||
retriever_state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): New key added to state, documents, that contains retrieved documents
|
||||
"""
|
||||
print("---RETRIEVE---")
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
query = state["rewritten_query"]
|
||||
|
||||
# Retrieval
|
||||
# TODO: add the actual retrieval, probably from search_tool.run()
|
||||
llm, fast_llm = get_default_llms()
|
||||
with get_session_context_manager() as db_session:
|
||||
top_sections = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query=query,
|
||||
),
|
||||
user=None,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
db_session=db_session,
|
||||
).reranked_sections
|
||||
print(len(top_sections))
|
||||
documents: list[InferenceSection] = []
|
||||
|
||||
return {
|
||||
"base_retrieval_docs": documents,
|
||||
"log_messages": generate_log_message(
|
||||
message="core - custom_retrieve",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_entity_term_extraction
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def decompose(state: QAState) -> dict[str, Any]:
|
||||
""" """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["original_question"]
|
||||
base_answer = state["base_answer"]
|
||||
|
||||
# get the entity term extraction dict and properly format it
|
||||
entity_term_extraction_dict = state["retrieved_entities_relationships"][
|
||||
"retrieved_entities_relationships"
|
||||
]
|
||||
|
||||
entity_term_extraction_str = format_entity_term_extraction(
|
||||
entity_term_extraction_dict
|
||||
)
|
||||
|
||||
initial_question_answers = state["initial_sub_qas"]
|
||||
|
||||
addressed_question_list = [
|
||||
x["sub_question"]
|
||||
for x in initial_question_answers
|
||||
if x["sub_answer_check"] == "yes"
|
||||
]
|
||||
failed_question_list = [
|
||||
x["sub_question"]
|
||||
for x in initial_question_answers
|
||||
if x["sub_answer_check"] == "no"
|
||||
]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=DEEP_DECOMPOSE_PROMPT.format(
|
||||
question=question,
|
||||
entity_term_extraction_str=entity_term_extraction_str,
|
||||
base_answer=base_answer,
|
||||
answered_sub_questions="\n - ".join(addressed_question_list),
|
||||
failed_sub_questions="\n - ".join(failed_question_list),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = state["fast_llm"]
|
||||
response = model.invoke(msg)
|
||||
|
||||
cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr())
|
||||
parsed_response = json.loads(cleaned_response)
|
||||
|
||||
sub_questions_dict = {}
|
||||
for sub_question_nr, sub_question_dict in enumerate(
|
||||
parsed_response["sub_questions"]
|
||||
):
|
||||
sub_question_dict["answered"] = False
|
||||
sub_question_dict["verified"] = False
|
||||
sub_questions_dict[sub_question_nr] = sub_question_dict
|
||||
|
||||
return {
|
||||
"decomposed_sub_questions_dict": sub_questions_dict,
|
||||
"log_messages": generate_log_message(
|
||||
message="deep - decompose",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,61 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
|
||||
from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_docs
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace
|
||||
|
||||
|
||||
# aggregate sub questions and answers
|
||||
def deep_answer_generation(state: QAState) -> dict[str, Any]:
|
||||
"""
|
||||
Generate answer
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with re-phrased question
|
||||
"""
|
||||
print("---DEEP GENERATE---")
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["original_question"]
|
||||
docs = state["deduped_retrieval_docs"]
|
||||
|
||||
deep_answer_context = state["core_answer_dynamic_context"]
|
||||
|
||||
print(f"Number of verified retrieval docs - deep: {len(docs)}")
|
||||
|
||||
combined_context = normalize_whitespace(
|
||||
COMBINED_CONTEXT.format(
|
||||
deep_answer_context=deep_answer_context, formated_docs=format_docs(docs)
|
||||
)
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=MODIFIED_RAG_PROMPT.format(
|
||||
question=question, combined_context=combined_context
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = state["fast_llm"]
|
||||
response = model.invoke(msg)
|
||||
|
||||
return {
|
||||
"deep_answer": response.content,
|
||||
"log_messages": generate_log_message(
|
||||
message="deep - deep answer generation",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,11 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
|
||||
|
||||
def dummy_start(state: QAState) -> dict[str, Any]:
|
||||
"""
|
||||
Dummy node to set the start time
|
||||
"""
|
||||
return {"start_time": datetime.now()}
|
||||
@@ -1,51 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
|
||||
from danswer.agent_search.primary_graph.prompts import ENTITY_TERM_PROMPT
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_docs
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
from danswer.llm.factory import get_default_llms
|
||||
|
||||
|
||||
def entity_term_extraction(state: QAState) -> dict[str, Any]:
|
||||
"""Extract entities and terms from the question and context"""
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["original_question"]
|
||||
docs = state["deduped_retrieval_docs"]
|
||||
|
||||
doc_context = format_docs(docs)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
|
||||
)
|
||||
]
|
||||
_, fast_llm = get_default_llms()
|
||||
# Grader
|
||||
llm_response_list = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
# structured_response_format={"type": "json_object", "schema": RewrittenQueries.model_json_schema()},
|
||||
# structured_response_format=RewrittenQueries.model_json_schema(),
|
||||
)
|
||||
)
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
|
||||
cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
|
||||
parsed_response = json.loads(cleaned_response)
|
||||
|
||||
return {
|
||||
"retrieved_entities_relationships": parsed_response,
|
||||
"log_messages": generate_log_message(
|
||||
message="deep - entity term extraction",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,85 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def final_stuff(state: QAState) -> dict[str, Any]:
|
||||
"""
|
||||
Invokes the agent model to generate a response based on the current state. Given
|
||||
the question, it will decide to retrieve using the retriever tool, or simply end.
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with the agent response appended to messages
|
||||
"""
|
||||
print("---FINAL---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
messages = state["log_messages"]
|
||||
time_ordered_messages = [x.pretty_repr() for x in messages]
|
||||
time_ordered_messages.sort()
|
||||
|
||||
print("Message Log:")
|
||||
print("\n".join(time_ordered_messages))
|
||||
|
||||
initial_sub_qas = state["initial_sub_qas"]
|
||||
initial_sub_qa_list = []
|
||||
for initial_sub_qa in initial_sub_qas:
|
||||
if initial_sub_qa["sub_answer_check"] == "yes":
|
||||
initial_sub_qa_list.append(
|
||||
f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----'
|
||||
)
|
||||
|
||||
initial_sub_qa_context = "\n".join(initial_sub_qa_list)
|
||||
|
||||
log_message = generate_log_message(
|
||||
message="all - final_stuff",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
)
|
||||
|
||||
print(log_message)
|
||||
print("--------------------------------")
|
||||
|
||||
base_answer = state["base_answer"]
|
||||
|
||||
print(f"Final Base Answer:\n{base_answer}")
|
||||
print("--------------------------------")
|
||||
print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
|
||||
print("--------------------------------")
|
||||
|
||||
if not state.get("deep_answer"):
|
||||
print("No Deep Answer was required")
|
||||
return {
|
||||
"log_messages": log_message,
|
||||
}
|
||||
|
||||
deep_answer = state["deep_answer"]
|
||||
sub_qas = state["sub_qas"]
|
||||
sub_qa_list = []
|
||||
for sub_qa in sub_qas:
|
||||
if sub_qa["sub_answer_check"] == "yes":
|
||||
sub_qa_list.append(
|
||||
f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----'
|
||||
)
|
||||
|
||||
sub_qa_context = "\n".join(sub_qa_list)
|
||||
|
||||
print(f"Final Base Answer:\n{base_answer}")
|
||||
print("--------------------------------")
|
||||
print(f"Final Deep Answer:\n{deep_answer}")
|
||||
print("--------------------------------")
|
||||
print("Sub Questions and Answers:")
|
||||
print(sub_qa_context)
|
||||
|
||||
return {
|
||||
"log_messages": generate_log_message(
|
||||
message="all - final_stuff",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,52 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_docs
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def generate(state: QAState) -> dict[str, Any]:
|
||||
"""
|
||||
Generate answer
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with re-phrased question
|
||||
"""
|
||||
print("---GENERATE---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["original_question"]
|
||||
docs = state["deduped_retrieval_docs"]
|
||||
|
||||
print(f"Number of verified retrieval docs: {len(docs)}")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=BASE_RAG_PROMPT.format(question=question, context=format_docs(docs))
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
llm = state["fast_llm"]
|
||||
response = list(
|
||||
llm.stream(
|
||||
prompt=msg,
|
||||
structured_response_format=None,
|
||||
)
|
||||
)
|
||||
|
||||
return {
|
||||
"base_answer": response[0].pretty_repr(),
|
||||
"log_messages": generate_log_message(
|
||||
message="core - generate",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,72 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.primary_graph.prompts import INITIAL_RAG_PROMPT
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_docs
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def generate_initial(state: QAState) -> dict[str, Any]:
|
||||
"""
|
||||
Generate answer
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with re-phrased question
|
||||
"""
|
||||
print("---GENERATE INITIAL---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["original_question"]
|
||||
docs = state["deduped_retrieval_docs"]
|
||||
print(f"Number of verified retrieval docs - base: {len(docs)}")
|
||||
|
||||
sub_question_answers = state["initial_sub_qas"]
|
||||
|
||||
sub_question_answers_list = []
|
||||
|
||||
_SUB_QUESTION_ANSWER_TEMPLATE = """
|
||||
Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
|
||||
"""
|
||||
for sub_question_answer_dict in sub_question_answers:
|
||||
if (
|
||||
sub_question_answer_dict["sub_answer_check"] == "yes"
|
||||
and len(sub_question_answer_dict["sub_answer"]) > 0
|
||||
and sub_question_answer_dict["sub_answer"] != "I don't know"
|
||||
):
|
||||
sub_question_answers_list.append(
|
||||
_SUB_QUESTION_ANSWER_TEMPLATE.format(
|
||||
sub_question=sub_question_answer_dict["sub_question"],
|
||||
sub_answer=sub_question_answer_dict["sub_answer"],
|
||||
)
|
||||
)
|
||||
|
||||
sub_question_answer_str = "\n\n------\n\n".join(sub_question_answers_list)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=INITIAL_RAG_PROMPT.format(
|
||||
question=question,
|
||||
context=format_docs(docs),
|
||||
answered_sub_questions=sub_question_answer_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = state["fast_llm"]
|
||||
response = model.invoke(msg)
|
||||
|
||||
return {
|
||||
"base_answer": response.pretty_repr(),
|
||||
"log_messages": generate_log_message(
|
||||
message="core - generate initial",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.primary_graph.prompts import INITIAL_DECOMPOSITION_PROMPT
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def main_decomp_base(state: QAState) -> dict[str, Any]:
|
||||
"""
|
||||
Perform an initial question decomposition, incl. one search term
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with initial decomposition
|
||||
"""
|
||||
|
||||
print("---INITIAL DECOMP---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["original_question"]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
|
||||
)
|
||||
]
|
||||
|
||||
# Get the rewritten queries in a defined format
|
||||
model = state["fast_llm"]
|
||||
response = model.invoke(msg)
|
||||
|
||||
content = response.pretty_repr()
|
||||
list_of_subquestions = clean_and_parse_list_string(content)
|
||||
|
||||
decomp_list = []
|
||||
|
||||
for sub_question_nr, sub_question in enumerate(list_of_subquestions):
|
||||
sub_question_str = sub_question["sub_question"].strip()
|
||||
# temporarily
|
||||
sub_question_search_queries = [sub_question["search_term"]]
|
||||
|
||||
decomp_list.append(
|
||||
{
|
||||
"sub_question_str": sub_question_str,
|
||||
"sub_question_search_queries": sub_question_search_queries,
|
||||
"sub_question_nr": sub_question_nr,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"initial_sub_questions": decomp_list,
|
||||
"sub_query_start_time": node_start_time,
|
||||
"log_messages": generate_log_message(
|
||||
message="core - initial decomp",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,55 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
|
||||
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def rewrite(state: QAState) -> dict[str, Any]:
|
||||
"""
|
||||
Transform the initial question into more suitable search queries.
|
||||
|
||||
Args:
|
||||
qa_state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with re-phrased question
|
||||
"""
|
||||
print("---STARTING GRAPH---")
|
||||
graph_start_time = datetime.now()
|
||||
|
||||
print("---TRANSFORM QUERY---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["original_question"]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REWRITE_PROMPT_MULTI.format(question=question),
|
||||
)
|
||||
]
|
||||
|
||||
# Get the rewritten queries in a defined format
|
||||
fast_llm = state["fast_llm"]
|
||||
llm_response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
structured_response_format=RewrittenQueries.model_json_schema(),
|
||||
)
|
||||
)
|
||||
|
||||
formatted_response: RewrittenQueries = json.loads(llm_response[0].pretty_repr())
|
||||
|
||||
return {
|
||||
"rewritten_queries": formatted_response.rewritten_queries,
|
||||
"log_messages": generate_log_message(
|
||||
message="core - rewrite",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=graph_start_time,
|
||||
),
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
# aggregate sub questions and answers
|
||||
def sub_qa_level_aggregator(state: QAState) -> dict[str, Any]:
|
||||
sub_qas = state["sub_qas"]
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
dynamic_context_list = [
|
||||
"Below you will find useful information to answer the original question:"
|
||||
]
|
||||
checked_sub_qas = []
|
||||
|
||||
for core_answer_sub_qa in sub_qas:
|
||||
question = core_answer_sub_qa["sub_question"]
|
||||
answer = core_answer_sub_qa["sub_answer"]
|
||||
verified = core_answer_sub_qa["sub_answer_check"]
|
||||
|
||||
if verified == "yes":
|
||||
dynamic_context_list.append(
|
||||
f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n"
|
||||
)
|
||||
checked_sub_qas.append({"sub_question": question, "sub_answer": answer})
|
||||
dynamic_context = "\n".join(dynamic_context_list)
|
||||
|
||||
return {
|
||||
"core_answer_dynamic_context": dynamic_context,
|
||||
"checked_sub_qas": checked_sub_qas,
|
||||
"log_messages": generate_log_message(
|
||||
message="deep - sub qa level aggregator",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.primary_graph.states import QAState
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def sub_qa_manager(state: QAState) -> dict[str, Any]:
|
||||
""" """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
sub_questions_dict = state["decomposed_sub_questions_dict"]
|
||||
|
||||
sub_questions = {}
|
||||
|
||||
for sub_question_nr, sub_question_dict in sub_questions_dict.items():
|
||||
sub_questions[sub_question_nr] = sub_question_dict["sub_question"]
|
||||
|
||||
return {
|
||||
"sub_questions": sub_questions,
|
||||
"num_new_question_iterations": 0,
|
||||
"log_messages": generate_log_message(
|
||||
message="deep - sub qa manager",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,59 +0,0 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.primary_graph.states import VerifierState
|
||||
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
|
||||
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def verifier(state: VerifierState) -> dict[str, Any]:
|
||||
"""
|
||||
Check whether the document is relevant for the original user question
|
||||
|
||||
Args:
|
||||
state (VerifierState): The current state
|
||||
|
||||
Returns:
|
||||
dict: ict: The updated state with the final decision
|
||||
"""
|
||||
|
||||
print("---VERIFY QUTPUT---")
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["question"]
|
||||
document_content = state["document"].combined_content
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=VERIFIER_PROMPT.format(
|
||||
question=question, document_content=document_content
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
llm = state["fast_llm"]
|
||||
response = list(
|
||||
llm.stream(
|
||||
prompt=msg,
|
||||
structured_response_format=BinaryDecision.model_json_schema(),
|
||||
)
|
||||
)
|
||||
|
||||
raw_response = json.loads(response[0].pretty_repr())
|
||||
formatted_response = BinaryDecision.model_validate(raw_response)
|
||||
|
||||
return {
|
||||
"deduped_retrieval_docs": [state["document"]]
|
||||
if formatted_response.decision == "yes"
|
||||
else [],
|
||||
"log_messages": generate_log_message(
|
||||
message=f"core - verifier: {formatted_response.decision}",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,86 +0,0 @@
|
||||
INITIAL_DECOMPOSITION_PROMPT = """ \n
|
||||
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
|
||||
answer the original question. The purpose for this decomposition is to isolate individulal entities
|
||||
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
|
||||
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
|
||||
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
|
||||
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
|
||||
|
||||
For each sub-question, please also create one search term that can be used to retrieve relevant
|
||||
documents from a document store.
|
||||
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Please formulate your answer as a list of json objects with the following format:
|
||||
|
||||
[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]
|
||||
|
||||
Answer:
|
||||
"""
|
||||
|
||||
INITIAL_RAG_PROMPT = """ \n
|
||||
You are an assistant for question-answering tasks. Use the information provided below - and only the
|
||||
provided information - to answer the provided question.
|
||||
|
||||
The information provided below consists of:
|
||||
1) a number of answered sub-questions - these are very important(!) and definitely should be
|
||||
considered to answer the question.
|
||||
2) a number of documents that were also deemed relevant for the question.
|
||||
|
||||
If you don't know the answer or if the provided information is empty or insufficient, just say
|
||||
"I don't know". Do not use your internal knowledge!
|
||||
|
||||
Again, only use the provided informationand do not use your internal knowledge! It is a matter of life
|
||||
and death that you do NOT use your internal knowledge, just the provided information!
|
||||
|
||||
Try to keep your answer concise.
|
||||
|
||||
And here is the question and the provided information:
|
||||
\n
|
||||
\nQuestion:\n {question}
|
||||
|
||||
\nAnswered Sub-questions:\n {answered_sub_questions}
|
||||
|
||||
\nContext:\n {context} \n\n
|
||||
\n\n
|
||||
|
||||
Answer:"""
|
||||
|
||||
ENTITY_TERM_PROMPT = """ \n
|
||||
Based on the original question and the context retieved from a dataset, please generate a list of
|
||||
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
|
||||
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.
|
||||
|
||||
\n\n
|
||||
Here is the original question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
And here is the context retrieved:
|
||||
\n ------- \n
|
||||
{context}
|
||||
\n ------- \n
|
||||
|
||||
Please format your answer as a json object in the following format:
|
||||
|
||||
{{"retrieved_entities_relationships": {{
|
||||
"entities": [{{
|
||||
"entity_name": <assign a name for the entity>,
|
||||
"entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
|
||||
}}],
|
||||
"relationships": [{{
|
||||
"name": <assign a name for the relationship>,
|
||||
"type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
|
||||
"entities": [<related entity name 1>, <related entity name 2>]
|
||||
}}],
|
||||
"terms": [{{
|
||||
"term_name": <assign a name for the term>,
|
||||
"term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
|
||||
"similar_to": <list terms that are similar to this term>
|
||||
}}]
|
||||
}}
|
||||
}}
|
||||
"""
|
||||
@@ -1,73 +0,0 @@
|
||||
import operator
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langgraph.graph.message import add_messages
|
||||
|
||||
from danswer.agent_search.shared_graph_utils.models import RewrittenQueries
|
||||
from danswer.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class QAState(TypedDict):
|
||||
# The 'main' state of the answer graph
|
||||
original_question: str
|
||||
graph_start_time: datetime
|
||||
# start time for parallel initial sub-questionn thread
|
||||
sub_query_start_time: datetime
|
||||
log_messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||
rewritten_queries: RewrittenQueries
|
||||
sub_questions: list[dict]
|
||||
initial_sub_questions: list[dict]
|
||||
ranked_subquestion_ids: list[int]
|
||||
decomposed_sub_questions_dict: dict
|
||||
rejected_sub_questions: Annotated[list[str], operator.add]
|
||||
rejected_sub_questions_handled: bool
|
||||
sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
initial_sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
checked_sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
base_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
|
||||
deduped_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
|
||||
reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
|
||||
retrieved_entities_relationships: dict
|
||||
questions_context: list[dict]
|
||||
qa_level: int
|
||||
top_chunks: list[InferenceSection]
|
||||
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
|
||||
num_new_question_iterations: int
|
||||
core_answer_dynamic_context: str
|
||||
dynamic_context: str
|
||||
initial_base_answer: str
|
||||
base_answer: str
|
||||
deep_answer: str
|
||||
|
||||
|
||||
class QAOuputState(TypedDict):
|
||||
# The 'main' output state of the answer graph. Removes all the intermediate states
|
||||
original_question: str
|
||||
log_messages: Annotated[Sequence[BaseMessage], add_messages]
|
||||
sub_questions: list[dict]
|
||||
sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
initial_sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
checked_sub_qas: Annotated[Sequence[dict], operator.add]
|
||||
reranked_retrieval_docs: Annotated[Sequence[InferenceSection], operator.add]
|
||||
retrieved_entities_relationships: dict
|
||||
top_chunks: list[InferenceSection]
|
||||
sub_question_top_chunks: Annotated[Sequence[dict], operator.add]
|
||||
base_answer: str
|
||||
deep_answer: str
|
||||
|
||||
|
||||
class RetrieverState(TypedDict):
|
||||
# The state for the parallel Retrievers. They each need to see only one query
|
||||
rewritten_query: str
|
||||
graph_start_time: datetime
|
||||
|
||||
|
||||
class VerifierState(TypedDict):
|
||||
# The state for the parallel verification step. Each node execution need to see only one question/doc pair
|
||||
document: InferenceSection
|
||||
question: str
|
||||
graph_start_time: datetime
|
||||
@@ -1,22 +0,0 @@
|
||||
from danswer.agent_search.primary_graph.graph_builder import build_core_graph
|
||||
from danswer.llm.answering.answer import AnswerStream
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
def run_graph(
|
||||
query: str,
|
||||
llm: LLM,
|
||||
tools: list[Tool],
|
||||
) -> AnswerStream:
|
||||
graph = build_core_graph()
|
||||
|
||||
inputs = {
|
||||
"original_question": query,
|
||||
"messages": [],
|
||||
"tools": tools,
|
||||
"llm": llm,
|
||||
}
|
||||
compiled_graph = graph.compile()
|
||||
output = compiled_graph.invoke(input=inputs)
|
||||
yield from output
|
||||
@@ -1,16 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Pydantic models for structured outputs
|
||||
class RewrittenQueries(BaseModel):
|
||||
rewritten_queries: list[str]
|
||||
|
||||
|
||||
class BinaryDecision(BaseModel):
|
||||
decision: Literal["yes", "no"]
|
||||
|
||||
|
||||
class SubQuestions(BaseModel):
|
||||
sub_questions: list[str]
|
||||
@@ -1,342 +0,0 @@
|
||||
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
|
||||
Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a
|
||||
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
|
||||
enabling the system to search more broadly.
|
||||
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the querytext): """
|
||||
|
||||
|
||||
REWRITE_PROMPT_MULTI = """ \n
|
||||
Please create a list of 2-3 sample documents that could answer an original question. Each document
|
||||
should be about as long as the original question. \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
|
||||
|
||||
BASE_RAG_PROMPT = """ \n
|
||||
You are an assistant for question-answering tasks. Use the context provided below - and only the
|
||||
provided context - to answer the question. If you don't know the answer or if the provided context is
|
||||
empty, just say "I don't know". Do not use your internal knowledge!
|
||||
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
Use three sentences maximum and keep the answer concise.
|
||||
answer concise.\nQuestion:\n {question} \nContext:\n {context} \n\n
|
||||
\n\n
|
||||
Answer:"""
|
||||
|
||||
BASE_CHECK_PROMPT = """ \n
|
||||
Please check whether 1) the suggested answer seems to fully address the original question AND 2)the
|
||||
original question requests a simple, factual answer, and there are no ambiguities, judgements,
|
||||
aggregations, or any other complications that may require extra context. (I.e., if the question is
|
||||
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
|
||||
|
||||
Please only answer with 'yes' or 'no' \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the proposed answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
Please answer with yes or no:"""
|
||||
|
||||
VERIFIER_PROMPT = """ \n
|
||||
Please check whether the document seems to be relevant for the answer of the original question. Please
|
||||
only answer with 'yes' or 'no' \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the document text:
|
||||
\n ------- \n
|
||||
{document_content}
|
||||
\n ------- \n
|
||||
Please answer with yes or no:"""
|
||||
|
||||
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
|
||||
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
|
||||
answer the original question. The purpose for this decomposition is to isolate individulal entities
|
||||
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
|
||||
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
|
||||
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
|
||||
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
|
||||
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Please formulate your answer as a list of subquestions:
|
||||
|
||||
Answer:
|
||||
"""
|
||||
|
||||
REWRITE_PROMPT_SINGLE = """ \n
|
||||
Please convert an initial user question into a more appropriate search query for retrievel from a
|
||||
document store. \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Formulate the query: """
|
||||
|
||||
MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
|
||||
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
|
||||
Use three sentences maximum and keep the answer concise.
|
||||
Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
\nQuestion: {question}
|
||||
\nContext: {combined_context} \n
|
||||
|
||||
Answer:"""
|
||||
|
||||
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
|
||||
good enough. Also, some sub-questions had been answered and this information has been used to provide
|
||||
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
|
||||
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
|
||||
you have an idea of how the avaiolable data looks like.
|
||||
|
||||
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
|
||||
considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered
|
||||
4) The sub-questions that were suggested but not answered
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question, but in a way that does
|
||||
not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
resolve ambiguities, or address shortcoming of the initial answer
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
|
||||
generate a list of dictionaries with the following format:
|
||||
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
|
||||
sub-question using as a search phrase for the document store>}}, ...]
|
||||
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
\n ------- \n
|
||||
{answered_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
\n ------- \n
|
||||
{failed_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Again, please find questions that are NOT overlapping too much with the already answered
|
||||
sub-questions or those that already were suggested and failed.
|
||||
In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Please think through it step by step and then generate the list of json dictionaries with the following
|
||||
format:
|
||||
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"explanation": <explanation>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
DEEP_DECOMPOSE_PROMPT = """ \n
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
|
||||
good enough. Also, some sub-questions had been answered and this information has been used to provide
|
||||
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
|
||||
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
|
||||
you have an idea of how the avaiolable data looks like.
|
||||
|
||||
Your role is to generate 4-6 new sub-questions that would help to answer the initial question,
|
||||
considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered
|
||||
4) The sub-questions that were suggested but not answered
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question, but in a way that does
|
||||
not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
resolve ambiguities, or address shortcoming of the initial answer
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please also provide a search term that can be used to retrieve relevant
|
||||
documents from a document store.
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
\n ------- \n
|
||||
{answered_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
\n ------- \n
|
||||
{failed_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Again, please find questions that are NOT overlapping too much with the already answered
|
||||
sub-questions or those that already were suggested and failed.
|
||||
In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Generate the list of json dictionaries with the following format:
|
||||
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
DECOMPOSE_PROMPT = """ \n
|
||||
For an initial user question, please generate at 5-10 individual sub-questions whose answers would help
|
||||
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question.
|
||||
|
||||
In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
|
||||
document store, expressed as entities, relationships and terms. You can also think about the types
|
||||
mentioned in brackets
|
||||
|
||||
Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
and or resolve ambiguities
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
|
||||
generate a list of dictionaries with the following format:
|
||||
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
|
||||
sub-question using as a search phrase for the document store>}}, ...]
|
||||
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Don't be too specific unless the original question is specific.
|
||||
Please think through it step by step and then generate the list of json dictionaries with the following
|
||||
format:
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"explanation": <explanation>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
#### Consolidations
|
||||
COMBINED_CONTEXT = """-------
|
||||
Below you will find useful information to answer the original question. First, you see a number of
|
||||
sub-questions with their answers. This information should be considered to be more focussed and
|
||||
somewhat more specific to the original question as it tries to contextualized facts.
|
||||
After that will see the documents that were considered to be relevant to answer the original question.
|
||||
|
||||
Here are the sub-questions and their answers:
|
||||
\n\n {deep_answer_context} \n\n
|
||||
\n\n Here are the documents that were considered to be relevant to answer the original question:
|
||||
\n\n {formated_docs} \n\n
|
||||
----------------
|
||||
"""
|
||||
|
||||
SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
|
||||
Below you will find a question that we ultimately want to answer (the original question) and a list of
|
||||
motivations in arbitrary order for generated sub-questions that are supposed to help us answering the
|
||||
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
|
||||
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
|
||||
motivation and 2 is less relevant.)
|
||||
|
||||
Please rank the motivations in order of relevance for answering the original question. Also, try to
|
||||
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
|
||||
Ultimately, create a list with the motivation numbers where the number of the most relevant
|
||||
motivations comes first.
|
||||
|
||||
Here is the original question:
|
||||
\n\n {original_question} \n\n
|
||||
\n\n Here is the list of sub-question motivations:
|
||||
\n\n {sub_question_explanations} \n\n
|
||||
----------------
|
||||
|
||||
Please think step by step and then generate the ranked list of motivations.
|
||||
|
||||
Please format your answer as a json object in the following format:
|
||||
{{"reasonning": <explain your reasoning for the ranking>,
|
||||
"ranked_motivations": <ranked list of motivation numbers>}}
|
||||
"""
|
||||
@@ -1,91 +0,0 @@
|
||||
import ast
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.context.search.models import InferenceSection
|
||||
|
||||
|
||||
def normalize_whitespace(text: str) -> str:
|
||||
"""Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
|
||||
import re
|
||||
|
||||
return re.sub(r"\s+", " ", text.strip())
|
||||
|
||||
|
||||
# Post-processing
|
||||
def format_docs(docs: Sequence[InferenceSection]) -> str:
|
||||
return "\n\n".join(doc.combined_content for doc in docs)
|
||||
|
||||
|
||||
def clean_and_parse_list_string(json_string: str) -> list[dict]:
|
||||
# Remove markdown code block markers and any newline prefixes
|
||||
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
|
||||
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
|
||||
cleaned_string = " ".join(cleaned_string.split())
|
||||
# Parse the cleaned string into a Python dictionary
|
||||
return ast.literal_eval(cleaned_string)
|
||||
|
||||
|
||||
def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
|
||||
# Remove markdown code block markers and any newline prefixes
|
||||
cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
|
||||
cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
|
||||
cleaned_string = " ".join(cleaned_string.split())
|
||||
# Parse the cleaned string into a Python dictionary
|
||||
return json.loads(cleaned_string)
|
||||
|
||||
|
||||
def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
|
||||
entities = entity_term_extraction_dict["entities"]
|
||||
terms = entity_term_extraction_dict["terms"]
|
||||
relationships = entity_term_extraction_dict["relationships"]
|
||||
|
||||
entity_strs = ["\nEntities:\n"]
|
||||
for entity in entities:
|
||||
entity_str = f"{entity['entity_name']} ({entity['entity_type']})"
|
||||
entity_strs.append(entity_str)
|
||||
|
||||
entity_str = "\n - ".join(entity_strs)
|
||||
|
||||
relationship_strs = ["\n\nRelationships:\n"]
|
||||
for relationship in relationships:
|
||||
relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
|
||||
relationship_strs.append(relationship_str)
|
||||
|
||||
relationship_str = "\n - ".join(relationship_strs)
|
||||
|
||||
term_strs = ["\n\nTerms:\n"]
|
||||
for term in terms:
|
||||
term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
|
||||
term_strs.append(term_str)
|
||||
|
||||
term_str = "\n - ".join(term_strs)
|
||||
|
||||
return "\n".join(entity_strs + relationship_strs + term_strs)
|
||||
|
||||
|
||||
def _format_time_delta(time: timedelta) -> str:
|
||||
seconds_from_start = f"{((time).seconds):03d}"
|
||||
microseconds_from_start = f"{((time).microseconds):06d}"
|
||||
return f"{seconds_from_start}.{microseconds_from_start}"
|
||||
|
||||
|
||||
def generate_log_message(
|
||||
message: str,
|
||||
node_start_time: datetime,
|
||||
graph_start_time: datetime | None = None,
|
||||
) -> str:
|
||||
current_time = datetime.now()
|
||||
|
||||
if graph_start_time is not None:
|
||||
graph_time_str = _format_time_delta(current_time - graph_start_time)
|
||||
else:
|
||||
graph_time_str = "N/A"
|
||||
|
||||
node_time_str = _format_time_delta(current_time - node_start_time)
|
||||
|
||||
return f"{graph_time_str} ({node_time_str} s): {message}"
|
||||
@@ -23,9 +23,7 @@ def load_no_auth_user_preferences(store: KeyValueStore) -> UserPreferences:
|
||||
)
|
||||
return UserPreferences(**preferences_data)
|
||||
except KvKeyNotFoundError:
|
||||
return UserPreferences(
|
||||
chosen_assistants=None, default_model=None, auto_scroll=True
|
||||
)
|
||||
return UserPreferences(chosen_assistants=None, default_model=None)
|
||||
|
||||
|
||||
def fetch_no_auth_user(store: KeyValueStore) -> UserInfo:
|
||||
|
||||
@@ -87,7 +87,6 @@ from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
@@ -100,6 +99,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
|
||||
@@ -11,7 +11,6 @@ from celery.exceptions import WorkerShutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -333,16 +332,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock: RedisLock = sender.primary_worker_lock
|
||||
lock = sender.primary_worker_lock
|
||||
try:
|
||||
if lock.owned():
|
||||
try:
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
except Exception:
|
||||
logger.exception("Failed to release primary worker lock")
|
||||
except Exception:
|
||||
logger.exception("Failed to check if primary worker lock is owned")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release primary worker lock: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if primary worker lock is owned: {e}")
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
|
||||
@@ -11,7 +11,6 @@ from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -39,6 +38,7 @@ from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
@@ -116,13 +116,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
|
||||
# set thread_local=False since we don't control what thread the periodic task might
|
||||
# reacquire the lock with
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
|
||||
@@ -231,7 +227,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
return
|
||||
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
lock = worker.primary_worker_lock
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
|
||||
@@ -2,55 +2,54 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"task": "check_for_doc_permissions_sync",
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"task": "check_for_external_group_sync",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -29,7 +28,7 @@ class TaskDependencyError(RuntimeError):
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
|
||||
@@ -18,11 +18,9 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import upsert_document_by_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -84,7 +82,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
name="check_for_doc_permissions_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -166,7 +164,7 @@ def try_creating_permissions_sync_task(
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
"connector_permission_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -193,7 +191,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
name="connector_permission_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -263,12 +261,7 @@ def connector_permission_sync_generator_task(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=document_external_accesses,
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
self.app, lock, document_external_accesses, source_type
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
@@ -293,7 +286,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
name="update_external_document_permissions_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
@@ -304,8 +297,6 @@ def update_external_document_permissions_task(
|
||||
tenant_id: str | None,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
@@ -314,28 +305,18 @@ def update_external_document_permissions_task(
|
||||
external_access = document_external_access.external_access
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Add the users to the DB if they don't exist
|
||||
# Then we build the update requests to update vespa
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
)
|
||||
# Then we upsert the document's external permissions in postgres
|
||||
created_new_doc = upsert_document_external_perms(
|
||||
upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_string),
|
||||
)
|
||||
|
||||
if created_new_doc:
|
||||
# If a new document was created, we associate it with the cc_pair
|
||||
upsert_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_ids=[doc_id],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully synced postgres document permissions for {doc_id}"
|
||||
)
|
||||
|
||||
@@ -17,7 +17,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -32,14 +31,10 @@ from danswer.redis.redis_connector_ext_group_sync import (
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import (
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -90,7 +85,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
name="check_for_external_group_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -111,22 +106,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# We only want to sync one cc_pair per source type in
|
||||
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
|
||||
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
|
||||
# These are ordered by cc_pair id so the first one is the one we want
|
||||
cc_pairs_to_dedupe = get_cc_pairs_by_source(
|
||||
db_session, source, only_sync=True
|
||||
)
|
||||
# We only want to sync one cc_pair per source type
|
||||
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
|
||||
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
|
||||
cc_pairs = [
|
||||
cc_pair
|
||||
for cc_pair in cc_pairs
|
||||
if cc_pair.id != cc_pair_to_remove.id
|
||||
]
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
@@ -182,7 +161,7 @@ def try_creating_external_group_sync_task(
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
"connector_external_group_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -212,7 +191,7 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
name="connector_external_group_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
|
||||
@@ -23,7 +23,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import mark_ccpair_with_indexing_trigger
|
||||
@@ -157,7 +156,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
@@ -487,7 +486,7 @@ def try_creating_indexing_task(
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
"connector_indexing_proxy_task",
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
@@ -525,10 +524,7 @@ def try_creating_indexing_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
bind=True,
|
||||
acks_late=False,
|
||||
track_started=True,
|
||||
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
|
||||
)
|
||||
def connector_indexing_proxy_task(
|
||||
self: Task,
|
||||
@@ -584,64 +580,39 @@ def connector_indexing_proxy_task(
|
||||
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
"Indexing proxy - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
finally:
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
|
||||
job.cancel()
|
||||
break
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
continue
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -789,12 +760,9 @@ def connector_indexing_task(
|
||||
)
|
||||
break
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
@@ -13,13 +13,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
name="kombu_message_cleanup_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
|
||||
@@ -8,7 +8,6 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -21,7 +20,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
@@ -77,7 +75,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
name="check_for_pruning",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -186,7 +184,7 @@ def try_creating_prune_generator_task(
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
"connector_pruning_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -211,7 +209,7 @@ def try_creating_prune_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
name="connector_pruning_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -240,12 +238,9 @@ def connector_pruning_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
|
||||
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
@@ -9,7 +9,6 @@ from tenacity import RetryError
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
@@ -32,7 +31,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_cc_pair_as_permissions_synced
|
||||
@@ -81,7 +80,7 @@ logger = setup_logger()
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -655,28 +654,24 @@ def monitor_ccpair_indexing_taskset(
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
result_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
result_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"result_state={result_state} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
@@ -708,7 +703,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
@@ -819,7 +814,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
|
||||
@@ -2,79 +2,20 @@ import re
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.llm import fetch_existing_doc_sets
|
||||
from danswer.db.llm import fetch_existing_tools
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def prepare_chat_message_request(
|
||||
message_text: str,
|
||||
user: User | None,
|
||||
persona_id: int | None,
|
||||
# Does the question need to have a persona override
|
||||
persona_override_config: PersonaOverrideConfig | None,
|
||||
prompt: Prompt | None,
|
||||
message_ts_to_respond_to: str | None,
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
db_session: Session,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description=None,
|
||||
user_id=user.id if user else None,
|
||||
# If using an override, this id will be ignored later on
|
||||
persona_id=persona_id or DEFAULT_PERSONA_ID,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return CreateChatMessageRequest(
|
||||
chat_session_id=new_chat_session.id,
|
||||
parent_message_id=None, # It's a standalone chat session each time
|
||||
message=message_text,
|
||||
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
# Can always override the persona for the single query, if it's a normal persona
|
||||
# then it will be treated the same
|
||||
persona_override_config=persona_override_config,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=retrieval_details,
|
||||
rerank_settings=rerank_settings,
|
||||
)
|
||||
|
||||
|
||||
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
@@ -90,49 +31,9 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
match_highlights=inference_section.center_chunk.match_highlights,
|
||||
)
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
@@ -295,71 +196,3 @@ def extract_headers(
|
||||
if lowercase_key in headers:
|
||||
extracted_headers[lowercase_key] = headers[lowercase_key]
|
||||
return extracted_headers
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(schema),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
|
||||
@@ -4,14 +4,12 @@ from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import RetrievalDocs
|
||||
from danswer.context.search.models import SearchResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
@@ -27,7 +25,6 @@ class LlmDoc(BaseModel):
|
||||
updated_at: datetime | None
|
||||
link: str | None
|
||||
source_links: dict[int, str] | None
|
||||
match_highlights: list[str] | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
@@ -120,6 +117,20 @@ class StreamingError(BaseModel):
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
class DanswerContext(BaseModel):
|
||||
content: str
|
||||
document_id: str
|
||||
@@ -135,20 +146,14 @@ class DanswerAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class ChatDanswerBotResponse(BaseModel):
|
||||
answer: str | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
class QAResponse(SearchResponse, DanswerAnswer):
|
||||
quotes: list[DanswerQuote] | None
|
||||
contexts: list[DanswerContexts] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
@@ -160,41 +165,9 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| FileChatDisplay
|
||||
|
||||
@@ -7,13 +7,10 @@ from typing import cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import create_temporary_persona
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import FileChatDisplay
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
@@ -105,7 +102,6 @@ from danswer.tools.tool_implementations.internet_search.internet_search_tool imp
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
@@ -117,10 +113,8 @@ from danswer.tools.tool_implementations.search.search_tool import (
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
from danswer.utils.timing import log_function_time
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -262,7 +256,6 @@ def _get_force_search_settings(
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| DanswerContexts
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
@@ -293,8 +286,6 @@ def stream_chat_message_objects(
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
bypass_acl: bool = False,
|
||||
include_contexts: bool = False,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -331,31 +322,17 @@ def stream_chat_message_objects(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
# Allows users to specify a temporary persona (assistant) in the chat session
|
||||
# this takes highest priority since it's user specified
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
elif new_msg_req.persona_override_config:
|
||||
# Certain endpoints allow users to specify arbitrary persona settings
|
||||
# this should never conflict with the alternate_assistant_id
|
||||
persona = persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=new_msg_req.persona_override_config,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
|
||||
if not persona:
|
||||
raise RuntimeError("No persona specified or found for chat session")
|
||||
|
||||
# If a prompt override is specified via the API, use that with highest priority
|
||||
# but for saving it, we are just mapping it to an existing prompt
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
if prompt_id is None and persona.prompts:
|
||||
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
|
||||
@@ -578,34 +555,19 @@ def stream_chat_message_objects(
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
||||
if new_msg_req.persona_override_config:
|
||||
prompt_config = PromptConfig(
|
||||
system_prompt=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].system_prompt,
|
||||
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
|
||||
datetime_aware=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].datetime_aware,
|
||||
include_citations=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].include_citations,
|
||||
)
|
||||
elif prompt_override:
|
||||
if not final_msg.prompt:
|
||||
raise ValueError(
|
||||
"Prompt override cannot be applied, no base prompt found."
|
||||
)
|
||||
prompt_config = PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=prompt_override,
|
||||
)
|
||||
elif final_msg.prompt:
|
||||
prompt_config = PromptConfig.from_model(final_msg.prompt)
|
||||
else:
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
if not final_msg.prompt:
|
||||
raise RuntimeError("No Prompt found")
|
||||
|
||||
prompt_config = (
|
||||
PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
)
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
@@ -625,13 +587,11 @@ def stream_chat_message_objects(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
rerank_settings=new_msg_req.rerank_settings,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
bypass_acl=bypass_acl,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
@@ -645,7 +605,6 @@ def stream_chat_message_objects(
|
||||
additional_headers=custom_tool_additional_headers,
|
||||
),
|
||||
)
|
||||
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
@@ -777,8 +736,6 @@ def stream_chat_message_objects(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
|
||||
yield cast(DanswerContexts, packet.response)
|
||||
|
||||
elif isinstance(packet, StreamStopInfo):
|
||||
pass
|
||||
@@ -887,30 +844,3 @@ def stream_chat_message(
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def gather_stream_for_slack(
|
||||
packets: ChatPacketStream,
|
||||
) -> ChatDanswerBotResponse:
|
||||
response = ChatDanswerBotResponse()
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.docs = packet
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.chat_message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
return response
|
||||
|
||||
@@ -308,22 +308,6 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
|
||||
os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
|
||||
)
|
||||
|
||||
# Due to breakages in the confluence API, the timezone offset must be specified client side
|
||||
# to match the user's specified timezone.
|
||||
|
||||
# The current state of affairs:
|
||||
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
|
||||
# no API retrieves the user's timezone
|
||||
# All data is returned in UTC, so we can't derive the user's timezone from that
|
||||
|
||||
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
|
||||
# https://jira.atlassian.com/browse/CONFCLOUD-69670
|
||||
|
||||
# enter as a floating point offset from UTC in hours (-24 < val < 24)
|
||||
# this will be applied globally, so it probably makes sense to transition this to per
|
||||
# connector as some point.
|
||||
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
|
||||
|
||||
JIRA_CONNECTOR_LABELS_TO_SKIP = [
|
||||
ignored_tag
|
||||
for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
|
||||
@@ -522,6 +506,3 @@ API_KEY_HASH_ROUNDS = (
|
||||
|
||||
POD_NAME = os.environ.get("POD_NAME")
|
||||
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
@@ -31,8 +31,6 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -261,32 +259,6 @@ class DanswerCeleryPriority(int, Enum):
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
class DanswerCeleryTask:
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
|
||||
"update_external_document_permissions_task"
|
||||
)
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
|
||||
"connector_external_group_sync_generator_task"
|
||||
)
|
||||
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||
|
||||
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
|
||||
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3
|
||||
|
||||
@@ -4,8 +4,11 @@ import os
|
||||
# Danswer Slack Bot Configs
|
||||
#####
|
||||
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
|
||||
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
|
||||
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
|
||||
)
|
||||
# How much of the available input context can be used for thread context
|
||||
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
|
||||
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
|
||||
# Number of docs to display in "Reference Documents"
|
||||
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
|
||||
os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
|
||||
@@ -44,6 +47,17 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
|
||||
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
|
||||
os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
|
||||
)
|
||||
# Add a second LLM call post Answer to verify if the Answer is valid
|
||||
# Throws out answers that don't directly or fully answer the user query
|
||||
# This is the default for all DanswerBot channels unless the channel is configured individually
|
||||
# Set/unset by "Hide Non Answers"
|
||||
ENABLE_DANSWERBOT_REFLEXION = (
|
||||
os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
|
||||
)
|
||||
# Currently not support chain of thought, probably will add back later
|
||||
DANSWER_BOT_DISABLE_COT = True
|
||||
# if set, will default DanswerBot to use quotes and reference documents
|
||||
DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true"
|
||||
|
||||
# Maximum Questions Per Minute, Default Uncapped
|
||||
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None
|
||||
|
||||
@@ -70,9 +70,7 @@ GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int(
|
||||
)
|
||||
|
||||
# Typically, GenAI models nowadays are at least 4K tokens
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int(
|
||||
os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 4096
|
||||
)
|
||||
GEN_AI_MODEL_FALLBACK_MAX_TOKENS = 4096
|
||||
|
||||
# Number of tokens from chat history to include at maximum
|
||||
# 3000 should be enough context regardless of use, no need to include as much as possible
|
||||
|
||||
@@ -11,16 +11,11 @@ Connectors come in 3 different flows:
|
||||
- Load Connector:
|
||||
- Bulk indexes documents to reflect a point in time. This type of connector generally works by either pulling all
|
||||
documents via a connector's API or loads the documents from some sort of a dump file.
|
||||
- Poll Connector:
|
||||
- Poll connector:
|
||||
- Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest
|
||||
changes and additions since the last round of polling. This connector helps keep the document index up to date
|
||||
without needing to fetch/embed/index every document which would be too slow to do frequently on large sets of
|
||||
documents.
|
||||
- Slim Connector:
|
||||
- This connector should be a lighter weight method of checking all documents in the source to see if they still exist.
|
||||
- This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves.
|
||||
- This is used by our pruning job which removes old documents from the index.
|
||||
- The optional start and end datetimes can be ignored.
|
||||
- Event Based connectors:
|
||||
- Connectors that listen to events and update documents accordingly.
|
||||
- Currently not used by the background job, this exists for future design purposes.
|
||||
@@ -31,14 +26,8 @@ Refer to [interfaces.py](https://github.com/danswer-ai/danswer/blob/main/backend
|
||||
and this first contributor created Pull Request for a new connector (Shoutout to Dan Brown):
|
||||
[Reference Pull Request](https://github.com/danswer-ai/danswer/pull/139)
|
||||
|
||||
For implementing a Slim Connector, refer to the comments in this PR:
|
||||
[Slim Connector PR](https://github.com/danswer-ai/danswer/pull/3303/files)
|
||||
|
||||
All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector.
|
||||
|
||||
|
||||
#### Implementing the new Connector
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, SlimConnector, or EventConnector.
|
||||
The connector must subclass one or more of LoadConnector, PollConnector, or EventConnector.
|
||||
|
||||
The `__init__` should take arguments for configuring what documents the connector will and where it finds those
|
||||
documents. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -71,7 +69,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
@@ -107,8 +104,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
@@ -209,14 +204,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
logger.debug(f"_fetch_document_batches: {page['id']}")
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
@@ -249,10 +242,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
|
||||
@@ -134,32 +134,6 @@ class OnyxConfluence(Confluence):
|
||||
super(OnyxConfluence, self).__init__(url, *args, **kwargs)
|
||||
self._wrap_methods()
|
||||
|
||||
def get_current_user(self, expand: str | None = None) -> Any:
|
||||
"""
|
||||
Implements a method that isn't in the third party client.
|
||||
|
||||
Get information about the current user
|
||||
:param expand: OPTIONAL expand for get status of user.
|
||||
Possible param is "status". Results are "Active, Deactivated"
|
||||
:return: Returns the user details
|
||||
"""
|
||||
|
||||
from atlassian.errors import ApiPermissionError # type:ignore
|
||||
|
||||
url = "rest/api/user/current"
|
||||
params = {}
|
||||
if expand:
|
||||
params["expand"] = expand
|
||||
try:
|
||||
response = self.get(url, params=params)
|
||||
except HTTPError as e:
|
||||
if e.response.status_code == 403:
|
||||
raise ApiPermissionError(
|
||||
"The calling user does not have permission", reason=e
|
||||
)
|
||||
raise
|
||||
return response
|
||||
|
||||
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
@@ -332,13 +306,6 @@ def _validate_connector_configuration(
|
||||
)
|
||||
spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)
|
||||
|
||||
# uncomment the following for testing
|
||||
# the following is an attempt to retrieve the user's timezone
|
||||
# Unfornately, all data is returned in UTC regardless of the user's time zone
|
||||
# even tho CQL parses incoming times based on the user's time zone
|
||||
# space_key = spaces["results"][0]["key"]
|
||||
# space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")
|
||||
|
||||
if not spaces:
|
||||
raise RuntimeError(
|
||||
f"No spaces found at {wiki_base}! "
|
||||
|
||||
@@ -32,11 +32,7 @@ def get_user_email_from_username__server(
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
# For now, we'll just return a string that indicates failure
|
||||
# We may want to revert to returning None in the future
|
||||
# email = None
|
||||
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
email = None
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
@@ -12,15 +12,12 @@ from dateutil import parser
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.interfaces import GenerateDocumentsOutput
|
||||
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
|
||||
from danswer.connectors.interfaces import LoadConnector
|
||||
from danswer.connectors.interfaces import PollConnector
|
||||
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from danswer.connectors.interfaces import SlimConnector
|
||||
from danswer.connectors.models import ConnectorMissingCredentialError
|
||||
from danswer.connectors.models import Document
|
||||
from danswer.connectors.models import Section
|
||||
from danswer.connectors.models import SlimDocument
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -31,8 +28,6 @@ logger = setup_logger()
|
||||
SLAB_GRAPHQL_MAX_TRIES = 10
|
||||
SLAB_API_URL = "https://api.slab.com/v1/graphql"
|
||||
|
||||
_SLIM_BATCH_SIZE = 1000
|
||||
|
||||
|
||||
def run_graphql_request(
|
||||
graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES
|
||||
@@ -163,26 +158,21 @@ def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str:
|
||||
return urljoin(urljoin(base_url, "posts/"), url_id)
|
||||
|
||||
|
||||
class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class SlabConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
slab_bot_token: str | None = None,
|
||||
) -> None:
|
||||
self.base_url = base_url
|
||||
self.batch_size = batch_size
|
||||
self._slab_bot_token: str | None = None
|
||||
self.slab_bot_token = slab_bot_token
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self._slab_bot_token = credentials["slab_bot_token"]
|
||||
self.slab_bot_token = credentials["slab_bot_token"]
|
||||
return None
|
||||
|
||||
@property
|
||||
def slab_bot_token(self) -> str:
|
||||
if self._slab_bot_token is None:
|
||||
raise ConnectorMissingCredentialError("Slab")
|
||||
return self._slab_bot_token
|
||||
|
||||
def _iterate_posts(
|
||||
self, time_filter: Callable[[datetime], bool] | None = None
|
||||
) -> GenerateDocumentsOutput:
|
||||
@@ -237,21 +227,3 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
yield from self._iterate_posts(
|
||||
time_filter=lambda t: start_time <= t <= end_time
|
||||
)
|
||||
|
||||
def retrieve_all_slim_documents(
|
||||
self,
|
||||
start: SecondsSinceUnixEpoch | None = None,
|
||||
end: SecondsSinceUnixEpoch | None = None,
|
||||
) -> GenerateSlimDocumentOutput:
|
||||
slim_doc_batch: list[SlimDocument] = []
|
||||
for post_id in get_all_post_ids(self.slab_bot_token):
|
||||
slim_doc_batch.append(
|
||||
SlimDocument(
|
||||
id=post_id,
|
||||
)
|
||||
)
|
||||
if len(slim_doc_batch) >= _SLIM_BATCH_SIZE:
|
||||
yield slim_doc_batch
|
||||
slim_doc_batch = []
|
||||
if slim_doc_batch:
|
||||
yield slim_doc_batch
|
||||
|
||||
@@ -16,7 +16,7 @@ from slack_sdk.models.blocks import SectionBlock
|
||||
from slack_sdk.models.blocks.basic_components import MarkdownTextObject
|
||||
from slack_sdk.models.blocks.block_elements import ImageElement
|
||||
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -40,7 +40,10 @@ from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
|
||||
from danswer.db.chat import get_chat_session_by_message_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.text_processing import decode_escapes
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
_MAX_BLURB_LEN = 45
|
||||
|
||||
@@ -324,7 +327,7 @@ def _build_sources_blocks(
|
||||
|
||||
|
||||
def _priority_ordered_documents_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
@@ -347,7 +350,7 @@ def _priority_ordered_documents_blocks(
|
||||
|
||||
|
||||
def _build_citations_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
@@ -366,8 +369,51 @@ def _build_citations_blocks(
|
||||
return citations_block
|
||||
|
||||
|
||||
def _build_quotes_block(
|
||||
quotes: list[DanswerQuote],
|
||||
) -> list[Block]:
|
||||
quote_lines: list[str] = []
|
||||
doc_to_quotes: dict[str, list[str]] = {}
|
||||
doc_to_link: dict[str, str] = {}
|
||||
doc_to_sem_id: dict[str, str] = {}
|
||||
for q in quotes:
|
||||
quote = q.quote
|
||||
doc_id = q.document_id
|
||||
doc_link = q.link
|
||||
doc_name = q.semantic_identifier
|
||||
if doc_link and doc_name and doc_id and quote:
|
||||
if doc_id not in doc_to_quotes:
|
||||
doc_to_quotes[doc_id] = [quote]
|
||||
doc_to_link[doc_id] = doc_link
|
||||
doc_to_sem_id[doc_id] = (
|
||||
doc_name
|
||||
if q.source_type != DocumentSource.SLACK.value
|
||||
else "#" + doc_name
|
||||
)
|
||||
else:
|
||||
doc_to_quotes[doc_id].append(quote)
|
||||
|
||||
for doc_id, quote_strs in doc_to_quotes.items():
|
||||
quotes_str_clean = [
|
||||
replace_whitespaces_w_space(q_str).strip() for q_str in quote_strs
|
||||
]
|
||||
longest_quotes = sorted(quotes_str_clean, key=len, reverse=True)[:5]
|
||||
single_quote_str = "\n".join([f"```{q_str}```" for q_str in longest_quotes])
|
||||
link = doc_to_link[doc_id]
|
||||
sem_id = doc_to_sem_id[doc_id]
|
||||
quote_lines.append(
|
||||
f"<{link}|{sem_id}>:\n{remove_slack_text_interactions(single_quote_str)}"
|
||||
)
|
||||
|
||||
if not doc_to_quotes:
|
||||
return []
|
||||
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
|
||||
|
||||
def _build_qa_response_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
skip_quotes: bool = False,
|
||||
process_message_for_citations: bool = False,
|
||||
) -> list[Block]:
|
||||
retrieval_info = answer.docs
|
||||
@@ -376,10 +422,13 @@ def _build_qa_response_blocks(
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
quotes = answer.quotes.quotes if answer.quotes else None
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
quotes_blocks: list[Block] = []
|
||||
|
||||
filter_block: Block | None = None
|
||||
if (
|
||||
retrieval_info.applied_time_cutoff
|
||||
@@ -422,6 +471,16 @@ def _build_qa_response_blocks(
|
||||
answer_blocks = [
|
||||
SectionBlock(text=text) for text in _split_text(answer_processed)
|
||||
]
|
||||
if quotes:
|
||||
quotes_blocks = _build_quotes_block(quotes)
|
||||
|
||||
# if no quotes OR `_build_quotes_block()` did not give back any blocks
|
||||
if not quotes_blocks:
|
||||
quotes_blocks = [
|
||||
SectionBlock(
|
||||
text="*Warning*: no sources were quoted for this answer, so it may be unreliable 😔"
|
||||
)
|
||||
]
|
||||
|
||||
response_blocks: list[Block] = []
|
||||
|
||||
@@ -430,6 +489,9 @@ def _build_qa_response_blocks(
|
||||
|
||||
response_blocks.extend(answer_blocks)
|
||||
|
||||
if not skip_quotes:
|
||||
response_blocks.extend(quotes_blocks)
|
||||
|
||||
return response_blocks
|
||||
|
||||
|
||||
@@ -505,9 +567,10 @@ def build_follow_up_resolved_blocks(
|
||||
|
||||
|
||||
def build_slack_response_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
tenant_id: str | None,
|
||||
message_info: SlackMessageInfo,
|
||||
answer: OneShotQAResponse,
|
||||
persona: Persona | None,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
feedback_reminder_id: str | None,
|
||||
@@ -524,6 +587,7 @@ def build_slack_response_blocks(
|
||||
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
answer=answer,
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
)
|
||||
|
||||
@@ -553,7 +617,8 @@ def build_slack_response_blocks(
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if use_citations and answer.citations:
|
||||
if use_citations:
|
||||
# if citations are enabled, only show cited documents
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
@@ -572,5 +637,4 @@ def build_slack_response_blocks(
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
)
|
||||
|
||||
return all_blocks
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
@@ -8,36 +9,46 @@ from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
|
||||
from danswer.chat.chat_utils import prepare_chat_message_request
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.process_message import gather_stream_for_slack
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.configs.danswerbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
|
||||
from danswer.context.search.enums import OptionalSearchSetting
|
||||
from danswer.context.search.models import BaseFilters
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.handlers.utils import slackify_message_thread
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import SlackRateLimiter
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.persona import fetch_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.one_shot_answer.answer_question import get_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
|
||||
RT = TypeVar("RT") # return type
|
||||
@@ -72,14 +83,16 @@ def handle_regular_answer(
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
|
||||
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
|
||||
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
|
||||
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
|
||||
) -> bool:
|
||||
channel_conf = slack_channel_config.channel_config if slack_channel_config else None
|
||||
|
||||
messages = message_info.thread_messages
|
||||
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
user = None
|
||||
@@ -89,18 +102,9 @@ def handle_regular_answer(
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
prompt = None
|
||||
# If no persona is specified, use the default search based persona
|
||||
# This way slack flow always has a persona
|
||||
persona = slack_channel_config.persona if slack_channel_config else None
|
||||
if not persona:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
prompt = persona.prompts[0] if persona.prompts else None
|
||||
else:
|
||||
prompt = None
|
||||
if persona:
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
@@ -108,26 +112,6 @@ def handle_regular_answer(
|
||||
|
||||
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
|
||||
|
||||
# TODO: Add in support for Slack to truncate messages based on max LLM context
|
||||
# llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# llm_tokenizer = get_tokenizer(
|
||||
# model_name=llm.config.model_name,
|
||||
# provider_type=llm.config.model_provider,
|
||||
# )
|
||||
|
||||
# # In cases of threads, split the available tokens between docs and thread context
|
||||
# input_tokens = get_max_input_tokens(
|
||||
# model_name=llm.config.model_name,
|
||||
# model_provider=llm.config.model_provider,
|
||||
# )
|
||||
# max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
# combined_message = combine_message_thread(
|
||||
# messages, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer
|
||||
# )
|
||||
|
||||
combined_message = slackify_message_thread(messages)
|
||||
|
||||
bypass_acl = False
|
||||
if (
|
||||
slack_channel_config
|
||||
@@ -138,6 +122,13 @@ def handle_regular_answer(
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
|
||||
# figure out if we want to use citations or quotes
|
||||
use_citations = (
|
||||
not DANSWER_BOT_USE_QUOTES
|
||||
if slack_channel_config is None
|
||||
else slack_channel_config.response_type == SlackBotResponseType.CITATIONS
|
||||
)
|
||||
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
# if the message is not "/danswer" command, then it should have a message ts to respond to
|
||||
raise RuntimeError(
|
||||
@@ -150,23 +141,75 @@ def handle_regular_answer(
|
||||
backoff=2,
|
||||
)
|
||||
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
|
||||
def _get_slack_answer(
|
||||
new_message_request: CreateChatMessageRequest, danswer_user: User | None
|
||||
) -> ChatDanswerBotResponse:
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
|
||||
max_document_tokens: int | None = None
|
||||
max_history_tokens: int | None = None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=new_message_request,
|
||||
user=danswer_user,
|
||||
if len(new_message_request.messages) > 1:
|
||||
if new_message_request.persona_config:
|
||||
raise RuntimeError("Slack bot does not support persona config")
|
||||
elif new_message_request.persona_id is not None:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No persona id provided, this should never happen."
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
)
|
||||
max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
|
||||
remaining_tokens = input_tokens - max_history_tokens
|
||||
|
||||
query_text = new_message_request.messages[0].message
|
||||
if persona:
|
||||
max_document_tokens = compute_max_document_tokens_for_persona(
|
||||
persona=persona,
|
||||
actual_user_input=query_text,
|
||||
max_llm_token_override=remaining_tokens,
|
||||
)
|
||||
else:
|
||||
max_document_tokens = (
|
||||
remaining_tokens
|
||||
- 512 # Needs to be more than any of the QA prompts
|
||||
- check_number_of_tokens(query_text)
|
||||
)
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return None
|
||||
|
||||
# This also handles creating the query event in postgres
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
enable_reflexion=reflexion,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=True,
|
||||
)
|
||||
|
||||
answer = gather_stream_for_slack(packets)
|
||||
|
||||
if answer.error_msg:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
return answer
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
@@ -196,24 +239,26 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
# Always apply reranking settings if it exists, this is the non-streaming flow
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
answer_request = prepare_chat_message_request(
|
||||
message_text=combined_message,
|
||||
user=user,
|
||||
persona_id=persona.id,
|
||||
# This is not used in the Slack flow, only in the answer API
|
||||
persona_override_config=None,
|
||||
prompt=prompt,
|
||||
message_ts_to_respond_to=message_ts_to_respond_to,
|
||||
retrieval_details=retrieval_details,
|
||||
rerank_settings=None, # Rerank customization supported in Slack flow
|
||||
db_session=db_session,
|
||||
saved_search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
DirectQARequest(
|
||||
messages=messages,
|
||||
multilingual_query_expansion=saved_search_settings.multilingual_expansion
|
||||
if saved_search_settings
|
||||
else None,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
|
||||
if saved_search_settings
|
||||
else None,
|
||||
)
|
||||
|
||||
answer = _get_slack_answer(
|
||||
new_message_request=answer_request, danswer_user=user
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -314,7 +359,7 @@ def handle_regular_answer(
|
||||
top_docs = retrieval_info.top_documents
|
||||
if not top_docs and not should_respond_even_with_no_docs:
|
||||
logger.error(
|
||||
f"Unable to answer question: '{combined_message}' - no documents found"
|
||||
f"Unable to answer question: '{answer.rephrase}' - no documents found"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
@@ -335,18 +380,18 @@ def handle_regular_answer(
|
||||
)
|
||||
return True
|
||||
|
||||
only_respond_if_citations = (
|
||||
only_respond_with_citations_or_quotes = (
|
||||
channel_conf
|
||||
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
|
||||
)
|
||||
|
||||
has_citations_or_quotes = bool(answer.citations or answer.quotes)
|
||||
if (
|
||||
only_respond_if_citations
|
||||
and not answer.citations
|
||||
only_respond_with_citations_or_quotes
|
||||
and not has_citations_or_quotes
|
||||
and not message_info.bypass_filters
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
@@ -364,8 +409,9 @@ def handle_regular_answer(
|
||||
tenant_id=tenant_id,
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
persona=persona,
|
||||
channel_conf=channel_conf,
|
||||
use_citations=True, # No longer supporting quotes
|
||||
use_citations=use_citations,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,33 +1,8 @@
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
|
||||
|
||||
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
|
||||
# Note: this does not handle extremely long threads, every message will be included
|
||||
# with weaker LLMs, this could cause issues with exceeeding the token limit
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
for message in messages:
|
||||
if message.role == MessageType.USER:
|
||||
message_text = (
|
||||
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
)
|
||||
elif message.role == MessageType.ASSISTANT:
|
||||
message_text = f"AI said in Slack:\n{message.message}"
|
||||
else:
|
||||
message_text = (
|
||||
f"{message.role.value.upper()} said in Slack:\n{message.message}"
|
||||
)
|
||||
message_strs.append(message_text)
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def send_team_member_message(
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
|
||||
@@ -19,8 +19,6 @@ from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.app_configs import DEV_MODE
|
||||
from danswer.configs.app_configs import POD_NAME
|
||||
from danswer.configs.app_configs import POD_NAMESPACE
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
@@ -76,6 +74,7 @@ from danswer.db.slack_bot import fetch_slack_bots
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -251,7 +250,7 @@ class SlackbotHandler:
|
||||
nx=True,
|
||||
ex=TENANT_LOCK_EXPIRATION,
|
||||
)
|
||||
if not acquired and not DEV_MODE:
|
||||
if not acquired:
|
||||
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
|
||||
|
||||
class SlackMessageInfo(BaseModel):
|
||||
|
||||
@@ -30,13 +30,13 @@ from danswer.configs.danswerbot_configs import (
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.danswerbot.slack.constants import FeedbackVisibility
|
||||
from danswer.danswerbot.slack.models import ThreadMessage
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.prompts.miscellaneous_prompts import SLACK_LANGUAGE_REPHRASE_PROMPT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
|
||||
@@ -145,10 +145,16 @@ def get_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
only_one_shot: bool = False,
|
||||
limit: int = 50,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if only_one_shot:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(True))
|
||||
else:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
@@ -220,11 +226,12 @@ def delete_messages_and_files_from_chat_session(
|
||||
|
||||
def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str | None,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int | None, # Can be none if temporary persona is used
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
slack_thread_id: str | None = None,
|
||||
) -> ChatSession:
|
||||
@@ -234,6 +241,7 @@ def create_chat_session(
|
||||
description=description,
|
||||
llm_override=llm_override,
|
||||
prompt_override=prompt_override,
|
||||
one_shot=one_shot,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
slack_thread_id=slack_thread_id,
|
||||
)
|
||||
@@ -279,6 +287,8 @@ def duplicate_chat_session_for_user_from_slack(
|
||||
description="",
|
||||
llm_override=chat_session.llm_override,
|
||||
prompt_override=chat_session.prompt_override,
|
||||
# Chat sessions from Slack should put people in the chat UI, not the search
|
||||
one_shot=False,
|
||||
# Chat is in UI now so this is false
|
||||
danswerbot_flow=False,
|
||||
# Maybe we want this in the future to track if it was created from Slack
|
||||
|
||||
@@ -37,7 +37,6 @@ from danswer.configs.app_configs import POSTGRES_PORT
|
||||
from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -427,9 +426,7 @@ def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise BasicAuthenticationError(
|
||||
detail="User must authenticate",
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="User must authenticate")
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum as PyEnum
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import NotRequired
|
||||
@@ -125,7 +126,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
||||
# if specified, controls the assistants that are shown to the user + their order
|
||||
# if not specified, all assistants are shown
|
||||
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
@@ -963,8 +963,9 @@ class ChatSession(Base):
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# This chat created by DanswerBot
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
# One-shot direct answering, currently the two types of chats are not mixed
|
||||
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
danswerbot_flow: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# Only ever set to True if system is set to not hard-delete chats
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@@ -1486,6 +1487,11 @@ class ChannelConfig(TypedDict):
|
||||
show_continue_in_web_ui: NotRequired[bool] # defaults to False
|
||||
|
||||
|
||||
class SlackBotResponseType(str, PyEnum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
|
||||
|
||||
class SlackChannelConfig(Base):
|
||||
__tablename__ = "slack_channel_config"
|
||||
|
||||
@@ -1498,6 +1504,9 @@ class SlackChannelConfig(Base):
|
||||
channel_config: Mapped[ChannelConfig] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False
|
||||
)
|
||||
response_type: Mapped[SlackBotResponseType] = mapped_column(
|
||||
Enum(SlackBotResponseType, native_enum=False), nullable=False
|
||||
)
|
||||
|
||||
enable_auto_filters: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
|
||||
@@ -185,7 +185,7 @@ def create_update_persona(
|
||||
"persona_id": persona_id,
|
||||
"user": user,
|
||||
"db_session": db_session,
|
||||
**create_persona_request.model_dump(exclude={"users", "groups"}),
|
||||
**create_persona_request.dict(exclude={"users", "groups"}),
|
||||
}
|
||||
|
||||
persona = upsert_persona(**persona_data)
|
||||
@@ -415,6 +415,9 @@ def upsert_prompt(
|
||||
return prompt
|
||||
|
||||
|
||||
# NOTE: This operation cannot update persona configuration options that
|
||||
# are core to the persona, such as its display priority and
|
||||
# whether or not the assistant is a built-in / default assistant
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -446,12 +449,6 @@ def upsert_persona(
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
) -> Persona:
|
||||
"""
|
||||
NOTE: This operation cannot update persona configuration options that
|
||||
are core to the persona, such as its display priority and
|
||||
whether or not the assistant is a built-in / default assistant
|
||||
"""
|
||||
|
||||
if persona_id is not None:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
else:
|
||||
@@ -489,8 +486,6 @@ def upsert_persona(
|
||||
validate_persona_tools(tools)
|
||||
|
||||
if persona:
|
||||
# Built-in personas can only be updated through YAML configuration.
|
||||
# This ensures that core system personas are not modified unintentionally.
|
||||
if persona.builtin_persona and not builtin_persona:
|
||||
raise ValueError("Cannot update builtin persona with non-builtin.")
|
||||
|
||||
@@ -499,9 +494,6 @@ def upsert_persona(
|
||||
db_session=db_session, persona_id=persona.id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
# The following update excludes `default`, `built-in`, and display priority.
|
||||
# Display priority is handled separately in the `display-priority` endpoint.
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
persona.name = name
|
||||
persona.description = description
|
||||
persona.num_chunks = num_chunks
|
||||
@@ -766,8 +758,6 @@ def get_prompt_by_name(
|
||||
if user and user.role != UserRole.ADMIN:
|
||||
stmt = stmt.where(Prompt.user_id == user.id)
|
||||
|
||||
# Order by ID to ensure consistent result when multiple prompts exist
|
||||
stmt = stmt.order_by(Prompt.id).limit(1)
|
||||
result = db_session.execute(stmt).scalar_one_or_none()
|
||||
return result
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Persona__DocumentSet
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_default_prompt
|
||||
@@ -82,6 +83,7 @@ def insert_slack_channel_config(
|
||||
slack_bot_id: int,
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
) -> SlackChannelConfig:
|
||||
@@ -113,6 +115,7 @@ def insert_slack_channel_config(
|
||||
slack_bot_id=slack_bot_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=response_type,
|
||||
standard_answer_categories=existing_standard_answer_categories,
|
||||
enable_auto_filters=enable_auto_filters,
|
||||
)
|
||||
@@ -127,6 +130,7 @@ def update_slack_channel_config(
|
||||
slack_channel_config_id: int,
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
) -> SlackChannelConfig:
|
||||
@@ -166,6 +170,7 @@ def update_slack_channel_config(
|
||||
# will encounter `violates foreign key constraint` errors
|
||||
slack_channel_config.persona_id = persona_id
|
||||
slack_channel_config.channel_config = channel_config
|
||||
slack_channel_config.response_type = response_type
|
||||
slack_channel_config.standard_answer_categories = list(
|
||||
existing_standard_answer_categories
|
||||
)
|
||||
|
||||
@@ -59,12 +59,6 @@ class FileStore(ABC):
|
||||
Contents of the file and metadata dict
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def read_file_record(self, file_name: str) -> PGFileStore:
|
||||
"""
|
||||
Read the file record by the name
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, file_name: str) -> None:
|
||||
"""
|
||||
|
||||
@@ -18,12 +18,18 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.build import default_build_system_message
|
||||
from danswer.llm.answering.prompts.build import default_build_user_message
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
QuotesResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
from danswer.llm.interfaces import LLM
|
||||
@@ -208,23 +214,18 @@ class Answer:
|
||||
|
||||
search_result = SearchTool.get_search_result(current_llm_call) or []
|
||||
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
# if self.answer_style_config.citation_config:
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
# )
|
||||
# elif self.answer_style_config.quotes_config:
|
||||
# answer_handler = QuotesResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("No answer style config provided")
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
answer_handler: AnswerResponseHandler
|
||||
if self.answer_style_config.citation_config:
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
answer_handler = QuotesResponseHandler(
|
||||
context_docs=search_result,
|
||||
)
|
||||
else:
|
||||
raise ValueError("No answer style config provided")
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
@@ -29,6 +30,7 @@ if TYPE_CHECKING:
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| DanswerQuotes
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
|
||||
@@ -9,6 +9,9 @@ from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
CitationProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
QuotesProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -67,29 +70,28 @@ class CitationResponseHandler(AnswerResponseHandler):
|
||||
yield from self.citation_processor.process_token(content)
|
||||
|
||||
|
||||
# No longer in use, remove later
|
||||
# class QuotesResponseHandler(AnswerResponseHandler):
|
||||
# def __init__(
|
||||
# self,
|
||||
# context_docs: list[LlmDoc],
|
||||
# is_json_prompt: bool = True,
|
||||
# ):
|
||||
# self.quotes_processor = QuotesProcessor(
|
||||
# context_docs=context_docs,
|
||||
# is_json_prompt=is_json_prompt,
|
||||
# )
|
||||
class QuotesResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
):
|
||||
self.quotes_processor = QuotesProcessor(
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
# def handle_response_part(
|
||||
# self,
|
||||
# response_item: BaseMessage | None,
|
||||
# previous_response_items: list[BaseMessage],
|
||||
# ) -> Generator[ResponsePart, None, None]:
|
||||
# if response_item is None:
|
||||
# yield from self.quotes_processor.process_token(None)
|
||||
# return
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self.quotes_processor.process_token(None)
|
||||
return
|
||||
|
||||
# content = (
|
||||
# response_item.content if isinstance(response_item.content, str) else ""
|
||||
# )
|
||||
content = (
|
||||
response_item.content if isinstance(response_item.content, str) else ""
|
||||
)
|
||||
|
||||
# yield from self.quotes_processor.process_token(content)
|
||||
yield from self.quotes_processor.process_token(content)
|
||||
|
||||
@@ -67,9 +67,9 @@ class CitationProcessor:
|
||||
if piece_that_comes_after == "\n" and in_code_block(self.llm_out):
|
||||
self.curr_segment = self.curr_segment.replace("```", "```plaintext")
|
||||
|
||||
citation_pattern = r"\[(\d+)\]|\[\[(\d+)\]\]" # [1], [[1]], etc.
|
||||
citation_pattern = r"\[(\d+)\]"
|
||||
citations_found = list(re.finditer(citation_pattern, self.curr_segment))
|
||||
possible_citation_pattern = r"(\[+\d*$)" # [1, [, [[, [[2, etc.
|
||||
possible_citation_pattern = r"(\[\d*$)" # [1, [, etc
|
||||
possible_citation_found = re.search(
|
||||
possible_citation_pattern, self.curr_segment
|
||||
)
|
||||
@@ -77,15 +77,13 @@ class CitationProcessor:
|
||||
if len(citations_found) == 0 and len(self.llm_out) - self.past_cite_count > 5:
|
||||
self.current_citations = []
|
||||
|
||||
result = ""
|
||||
result = "" # Initialize result here
|
||||
if citations_found and not in_code_block(self.llm_out):
|
||||
last_citation_end = 0
|
||||
length_to_add = 0
|
||||
while len(citations_found) > 0:
|
||||
citation = citations_found.pop(0)
|
||||
numerical_value = int(
|
||||
next(group for group in citation.groups() if group is not None)
|
||||
)
|
||||
numerical_value = int(citation.group(1))
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
@@ -133,6 +131,14 @@ class CitationProcessor:
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
# Replace the citation in the current segment
|
||||
start, end = citation.span()
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[{target_citation_num}]"
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(target_citation_num)
|
||||
|
||||
@@ -143,7 +149,6 @@ class CitationProcessor:
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
start, end = citation.span()
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# THIS IS NO LONGER IN USE
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
@@ -6,10 +5,11 @@ from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
import regex
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
@@ -26,20 +26,6 @@ logger = setup_logger()
|
||||
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
def _extract_answer_quotes_freeform(
|
||||
answer_raw: str,
|
||||
) -> tuple[Optional[str], Optional[list[str]]]:
|
||||
|
||||
@@ -26,9 +26,7 @@ from langchain_core.messages.tool import ToolMessage
|
||||
from langchain_core.prompt_values import PromptValue
|
||||
|
||||
from danswer.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
|
||||
from danswer.configs.model_configs import (
|
||||
DISABLE_LITELLM_STREAMING,
|
||||
)
|
||||
from danswer.configs.model_configs import DISABLE_LITELLM_STREAMING
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.configs.model_configs import LITELLM_EXTRA_BODY
|
||||
from danswer.llm.interfaces import LLM
|
||||
@@ -163,9 +161,7 @@ def _convert_delta_to_message_chunk(
|
||||
|
||||
if role == "user":
|
||||
return HumanMessageChunk(content=content)
|
||||
# NOTE: if tool calls are present, then it's an assistant.
|
||||
# In Ollama, the role will be None for tool-calls
|
||||
elif role == "assistant" or tool_calls:
|
||||
elif role == "assistant":
|
||||
if tool_calls:
|
||||
tool_call = tool_calls[0]
|
||||
tool_name = tool_call.function.name or (curr_msg and curr_msg.name) or ""
|
||||
@@ -240,7 +236,6 @@ class DefaultMultiLLM(LLM):
|
||||
custom_config: dict[str, str] | None = None,
|
||||
extra_headers: dict[str, str] | None = None,
|
||||
extra_body: dict | None = LITELLM_EXTRA_BODY,
|
||||
model_kwargs: dict[str, Any] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
):
|
||||
self._timeout = timeout
|
||||
@@ -273,7 +268,7 @@ class DefaultMultiLLM(LLM):
|
||||
for k, v in custom_config.items():
|
||||
os.environ[k] = v
|
||||
|
||||
model_kwargs = model_kwargs or {}
|
||||
model_kwargs: dict[str, Any] = {}
|
||||
if extra_headers:
|
||||
model_kwargs.update({"extra_headers": extra_headers})
|
||||
if extra_body:
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.llm import fetch_default_provider
|
||||
@@ -14,20 +10,8 @@ from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from danswer.utils.headers import build_llm_extra_headers
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
|
||||
"""Ollama requires us to specify the max context window.
|
||||
|
||||
For now, just using the GEN_AI_MODEL_FALLBACK_MAX_TOKENS value.
|
||||
TODO: allow model-specific values to be configured via the UI.
|
||||
"""
|
||||
return {"num_ctx": GEN_AI_MODEL_FALLBACK_MAX_TOKENS} if provider == "ollama" else {}
|
||||
|
||||
|
||||
def get_main_llm_from_tuple(
|
||||
llms: tuple[LLM, LLM],
|
||||
@@ -36,15 +20,11 @@ def get_main_llm_from_tuple(
|
||||
|
||||
|
||||
def get_llms_for_persona(
|
||||
persona: Persona | PersonaOverrideConfig | None,
|
||||
persona: Persona,
|
||||
llm_override: LLMOverride | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> tuple[LLM, LLM]:
|
||||
if persona is None:
|
||||
logger.warning("No persona provided, using default LLMs")
|
||||
return get_default_llms()
|
||||
|
||||
model_provider_override = llm_override.model_provider if llm_override else None
|
||||
model_version_override = llm_override.model_version if llm_override else None
|
||||
temperature_override = llm_override.temperature if llm_override else None
|
||||
@@ -79,7 +59,6 @@ def get_llms_for_persona(
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
temperature=temperature_override,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
@@ -137,13 +116,11 @@ def get_llm(
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
custom_config: dict[str, str] | None = None,
|
||||
temperature: float | None = None,
|
||||
temperature: float = GEN_AI_TEMPERATURE,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> LLM:
|
||||
if temperature is None:
|
||||
temperature = GEN_AI_TEMPERATURE
|
||||
return DefaultMultiLLM(
|
||||
model_provider=provider,
|
||||
model_name=model,
|
||||
@@ -155,6 +132,5 @@ def get_llm(
|
||||
temperature=temperature,
|
||||
custom_config=custom_config,
|
||||
extra_headers=build_llm_extra_headers(additional_headers),
|
||||
model_kwargs=_build_extra_model_kwargs(provider),
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import io
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
@@ -386,62 +385,6 @@ def test_llm(llm: LLM) -> str | None:
|
||||
return error_msg
|
||||
|
||||
|
||||
def get_model_map() -> dict:
|
||||
starting_map = copy.deepcopy(cast(dict, litellm.model_cost))
|
||||
|
||||
# NOTE: we could add additional models here in the future,
|
||||
# but for now there is no point. Ollama allows the user to
|
||||
# to specify their desired max context window, and it's
|
||||
# unlikely to be standard across users even for the same model
|
||||
# (it heavily depends on their hardware). For now, we'll just
|
||||
# rely on GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this.
|
||||
# for model_name in [
|
||||
# "llama3.2",
|
||||
# "llama3.2:1b",
|
||||
# "llama3.2:3b",
|
||||
# "llama3.2:11b",
|
||||
# "llama3.2:90b",
|
||||
# ]:
|
||||
# starting_map[f"ollama/{model_name}"] = {
|
||||
# "max_tokens": 128000,
|
||||
# "max_input_tokens": 128000,
|
||||
# "max_output_tokens": 128000,
|
||||
# }
|
||||
|
||||
return starting_map
|
||||
|
||||
|
||||
def _strip_extra_provider_from_model_name(model_name: str) -> str:
|
||||
return model_name.split("/")[1] if "/" in model_name else model_name
|
||||
|
||||
|
||||
def _strip_colon_from_model_name(model_name: str) -> str:
|
||||
return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name
|
||||
|
||||
|
||||
def _find_model_obj(
|
||||
model_map: dict, provider: str, model_names: list[str | None]
|
||||
) -> dict | None:
|
||||
# Filter out None values and deduplicate model names
|
||||
filtered_model_names = [name for name in model_names if name]
|
||||
|
||||
# First try all model names with provider prefix
|
||||
for model_name in filtered_model_names:
|
||||
model_obj = model_map.get(f"{provider}/{model_name}")
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {provider}/{model_name}")
|
||||
return model_obj
|
||||
|
||||
# Then try all model names without provider prefix
|
||||
for model_name in filtered_model_names:
|
||||
model_obj = model_map.get(model_name)
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {model_name}")
|
||||
return model_obj
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_llm_max_tokens(
|
||||
model_map: dict,
|
||||
model_name: str,
|
||||
@@ -454,22 +397,22 @@ def get_llm_max_tokens(
|
||||
return GEN_AI_MAX_TOKENS
|
||||
|
||||
try:
|
||||
extra_provider_stripped_model_name = _strip_extra_provider_from_model_name(
|
||||
model_name
|
||||
)
|
||||
model_obj = _find_model_obj(
|
||||
model_map,
|
||||
model_provider,
|
||||
[
|
||||
model_name,
|
||||
# Remove leading extra provider. Usually for cases where user has a
|
||||
# customer model proxy which appends another prefix
|
||||
extra_provider_stripped_model_name,
|
||||
# remove :XXXX from the end, if present. Needed for ollama.
|
||||
_strip_colon_from_model_name(model_name),
|
||||
_strip_colon_from_model_name(extra_provider_stripped_model_name),
|
||||
],
|
||||
)
|
||||
model_obj = model_map.get(f"{model_provider}/{model_name}")
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {model_provider}/{model_name}")
|
||||
|
||||
if not model_obj:
|
||||
model_obj = model_map.get(model_name)
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {model_name}")
|
||||
|
||||
if not model_obj:
|
||||
model_name_split = model_name.split("/")
|
||||
if len(model_name_split) > 1:
|
||||
model_obj = model_map.get(model_name_split[1])
|
||||
if model_obj:
|
||||
logger.debug(f"Using model object for {model_name_split[1]}")
|
||||
|
||||
if not model_obj:
|
||||
raise RuntimeError(
|
||||
f"No litellm entry found for {model_provider}/{model_name}"
|
||||
@@ -545,7 +488,7 @@ def get_max_input_tokens(
|
||||
# `model_cost` dict is a named public interface:
|
||||
# https://litellm.vercel.app/docs/completion/token_usage#7-model_cost
|
||||
# model_map is litellm.model_cost
|
||||
litellm_model_map = get_model_map()
|
||||
litellm_model_map = litellm.model_cost
|
||||
|
||||
input_toks = (
|
||||
get_llm_max_tokens(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user