mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-17 07:45:47 +00:00
Compare commits
7 Commits
initial-im
...
text_view
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3481e7d443 | ||
|
|
b98caa933b | ||
|
|
458d5d59b3 | ||
|
|
aa7ab82dc8 | ||
|
|
962ea21150 | ||
|
|
a38373e07a | ||
|
|
d193c4f452 |
@@ -1,5 +1,5 @@
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from typing import Literal
|
||||
from typing import Any
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
import logging
|
||||
@@ -8,7 +8,6 @@ from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql import text
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from danswer.db.engine import build_connection_string
|
||||
@@ -36,18 +35,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
object: Any, name: str, type_: str, reflected: bool, compare_to: Any
|
||||
) -> bool:
|
||||
"""
|
||||
Determines whether a database object should be included in migrations.
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""Combine Search and Chat
|
||||
|
||||
Revision ID: 9f696734098f
|
||||
Revises: a8c2065484e6
|
||||
Create Date: 2024-11-27 15:32:19.694972
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "9f696734098f"
|
||||
down_revision = "a8c2065484e6"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column("chat_session", "description", nullable=True)
|
||||
op.drop_column("chat_session", "one_shot")
|
||||
op.drop_column("slack_channel_config", "response_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
|
||||
op.alter_column("chat_session", "description", nullable=False)
|
||||
op.add_column(
|
||||
"chat_session",
|
||||
sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
|
||||
)
|
||||
op.add_column(
|
||||
"slack_channel_config",
|
||||
sa.Column(
|
||||
"response_type", sa.String(), nullable=False, server_default="citations"
|
||||
),
|
||||
)
|
||||
@@ -1,6 +1,5 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
@@ -38,15 +37,8 @@ EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem,
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
name: str,
|
||||
type_: str,
|
||||
reflected: bool,
|
||||
compare_to: SchemaItem | None,
|
||||
) -> bool:
|
||||
|
||||
@@ -18,11 +18,6 @@ class ExternalAccess:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DocExternalAccess:
|
||||
"""
|
||||
This is just a class to wrap the external access and the document ID
|
||||
together. It's used for syncing document permissions to Redis.
|
||||
"""
|
||||
|
||||
external_access: ExternalAccess
|
||||
# The document ID
|
||||
doc_id: str
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from danswer.agent_search.answer_query.nodes.answer_check import answer_check
|
||||
from danswer.agent_search.answer_query.nodes.answer_generation import answer_generation
|
||||
from danswer.agent_search.answer_query.nodes.format_answer import format_answer
|
||||
from danswer.agent_search.answer_query.states import AnswerQueryInput
|
||||
from danswer.agent_search.answer_query.states import AnswerQueryOutput
|
||||
from danswer.agent_search.answer_query.states import AnswerQueryState
|
||||
from danswer.agent_search.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
|
||||
|
||||
def answer_query_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=AnswerQueryState,
|
||||
input=AnswerQueryInput,
|
||||
output=AnswerQueryOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
expanded_retrieval = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="expanded_retrieval_for_initial_decomp",
|
||||
action=expanded_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_check",
|
||||
action=answer_check,
|
||||
)
|
||||
graph.add_node(
|
||||
node="answer_generation",
|
||||
action=answer_generation,
|
||||
)
|
||||
graph.add_node(
|
||||
node="format_answer",
|
||||
action=format_answer,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="expanded_retrieval_for_initial_decomp",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="expanded_retrieval_for_initial_decomp",
|
||||
end_key="answer_generation",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_generation",
|
||||
end_key="answer_check",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="answer_check",
|
||||
end_key="format_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="format_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.context.search.models import SearchRequest
|
||||
|
||||
graph = answer_query_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="Who made Excel and what other products did they make?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
inputs = AnswerQueryInput(
|
||||
search_request=search_request,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
db_session=db_session,
|
||||
query_to_answer="Who made Excel?",
|
||||
)
|
||||
output = compiled_graph.invoke(
|
||||
input=inputs,
|
||||
# debug=True,
|
||||
# subgraphs=True,
|
||||
)
|
||||
print(output)
|
||||
# for namespace, chunk in compiled_graph.stream(
|
||||
# input=inputs,
|
||||
# # debug=True,
|
||||
# subgraphs=True,
|
||||
# ):
|
||||
# print(namespace)
|
||||
# print(chunk)
|
||||
@@ -1,30 +0,0 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
|
||||
from danswer.agent_search.answer_query.states import AnswerQueryState
|
||||
from danswer.agent_search.answer_query.states import QACheckOutput
|
||||
from danswer.agent_search.shared_graph_utils.prompts import BASE_CHECK_PROMPT
|
||||
|
||||
|
||||
def answer_check(state: AnswerQueryState) -> QACheckOutput:
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=BASE_CHECK_PROMPT.format(
|
||||
question=state["search_request"].query,
|
||||
base_answer=state["answer"],
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
fast_llm = state["fast_llm"]
|
||||
response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
)
|
||||
)
|
||||
|
||||
response_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
|
||||
return QACheckOutput(
|
||||
answer_quality=response_str,
|
||||
)
|
||||
@@ -1,32 +0,0 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
|
||||
from danswer.agent_search.answer_query.states import AnswerQueryState
|
||||
from danswer.agent_search.answer_query.states import QAGenerationOutput
|
||||
from danswer.agent_search.shared_graph_utils.prompts import BASE_RAG_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_docs
|
||||
|
||||
|
||||
def answer_generation(state: AnswerQueryState) -> QAGenerationOutput:
|
||||
query = state["query_to_answer"]
|
||||
docs = state["reordered_documents"]
|
||||
|
||||
print(f"Number of verified retrieval docs: {len(docs)}")
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=BASE_RAG_PROMPT.format(question=query, context=format_docs(docs))
|
||||
)
|
||||
]
|
||||
|
||||
fast_llm = state["fast_llm"]
|
||||
response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
)
|
||||
)
|
||||
|
||||
answer_str = merge_message_runs(response, chunk_separator="")[0].content
|
||||
return QAGenerationOutput(
|
||||
answer=answer_str,
|
||||
)
|
||||
@@ -1,16 +0,0 @@
|
||||
from danswer.agent_search.answer_query.states import AnswerQueryOutput
|
||||
from danswer.agent_search.answer_query.states import AnswerQueryState
|
||||
from danswer.agent_search.answer_query.states import SearchAnswerResults
|
||||
|
||||
|
||||
def format_answer(state: AnswerQueryState) -> AnswerQueryOutput:
|
||||
return AnswerQueryOutput(
|
||||
decomp_answer_results=[
|
||||
SearchAnswerResults(
|
||||
query=state["query_to_answer"],
|
||||
quality=state["answer_quality"],
|
||||
answer=state["answer"],
|
||||
documents=state["reordered_documents"],
|
||||
)
|
||||
],
|
||||
)
|
||||
@@ -1,45 +0,0 @@
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.agent_search.core_state import PrimaryState
|
||||
from danswer.agent_search.shared_graph_utils.operators import dedup_inference_sections
|
||||
from danswer.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class SearchAnswerResults(BaseModel):
|
||||
query: str
|
||||
answer: str
|
||||
quality: str
|
||||
documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
|
||||
|
||||
class QACheckOutput(TypedDict, total=False):
|
||||
answer_quality: str
|
||||
|
||||
|
||||
class QAGenerationOutput(TypedDict, total=False):
|
||||
answer: str
|
||||
|
||||
|
||||
class ExpandedRetrievalOutput(TypedDict):
|
||||
reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
|
||||
|
||||
class AnswerQueryState(
|
||||
PrimaryState,
|
||||
QACheckOutput,
|
||||
QAGenerationOutput,
|
||||
ExpandedRetrievalOutput,
|
||||
total=True,
|
||||
):
|
||||
query_to_answer: str
|
||||
|
||||
|
||||
class AnswerQueryInput(PrimaryState, total=True):
|
||||
query_to_answer: str
|
||||
|
||||
|
||||
class AnswerQueryOutput(TypedDict):
|
||||
decomp_answer_results: list[SearchAnswerResults]
|
||||
@@ -1,15 +0,0 @@
|
||||
from typing import TypedDict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.context.search.models import SearchRequest
|
||||
from danswer.llm.interfaces import LLM
|
||||
|
||||
|
||||
class PrimaryState(TypedDict, total=False):
|
||||
search_request: SearchRequest
|
||||
primary_llm: LLM
|
||||
fast_llm: LLM
|
||||
# a single session for the entire agent search
|
||||
# is fine if we are only reading
|
||||
db_session: Session
|
||||
@@ -1,114 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.main.states import MainState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import COMBINED_CONTEXT
|
||||
from danswer.agent_search.shared_graph_utils.prompts import MODIFIED_RAG_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_docs
|
||||
from danswer.agent_search.shared_graph_utils.utils import normalize_whitespace
|
||||
|
||||
|
||||
# aggregate sub questions and answers
|
||||
def deep_answer_generation(state: MainState) -> dict[str, Any]:
|
||||
"""
|
||||
Generate answer
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with re-phrased question
|
||||
"""
|
||||
print("---DEEP GENERATE---")
|
||||
|
||||
question = state["original_question"]
|
||||
docs = state["deduped_retrieval_docs"]
|
||||
|
||||
deep_answer_context = state["core_answer_dynamic_context"]
|
||||
|
||||
print(f"Number of verified retrieval docs - deep: {len(docs)}")
|
||||
|
||||
combined_context = normalize_whitespace(
|
||||
COMBINED_CONTEXT.format(
|
||||
deep_answer_context=deep_answer_context, formated_docs=format_docs(docs)
|
||||
)
|
||||
)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=MODIFIED_RAG_PROMPT.format(
|
||||
question=question, combined_context=combined_context
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = state["fast_llm"]
|
||||
response = model.invoke(msg)
|
||||
|
||||
return {
|
||||
"deep_answer": response.content,
|
||||
}
|
||||
|
||||
|
||||
def final_stuff(state: MainState) -> dict[str, Any]:
|
||||
"""
|
||||
Invokes the agent model to generate a response based on the current state. Given
|
||||
the question, it will decide to retrieve using the retriever tool, or simply end.
|
||||
|
||||
Args:
|
||||
state (messages): The current state
|
||||
|
||||
Returns:
|
||||
dict: The updated state with the agent response appended to messages
|
||||
"""
|
||||
print("---FINAL---")
|
||||
|
||||
messages = state["log_messages"]
|
||||
time_ordered_messages = [x.pretty_repr() for x in messages]
|
||||
time_ordered_messages.sort()
|
||||
|
||||
print("Message Log:")
|
||||
print("\n".join(time_ordered_messages))
|
||||
|
||||
initial_sub_qas = state["initial_sub_qas"]
|
||||
initial_sub_qa_list = []
|
||||
for initial_sub_qa in initial_sub_qas:
|
||||
if initial_sub_qa["sub_answer_check"] == "yes":
|
||||
initial_sub_qa_list.append(
|
||||
f' Question:\n {initial_sub_qa["sub_question"]}\n --\n Answer:\n {initial_sub_qa["sub_answer"]}\n -----'
|
||||
)
|
||||
|
||||
initial_sub_qa_context = "\n".join(initial_sub_qa_list)
|
||||
|
||||
base_answer = state["base_answer"]
|
||||
|
||||
print(f"Final Base Answer:\n{base_answer}")
|
||||
print("--------------------------------")
|
||||
print(f"Initial Answered Sub Questions:\n{initial_sub_qa_context}")
|
||||
print("--------------------------------")
|
||||
|
||||
if not state.get("deep_answer"):
|
||||
print("No Deep Answer was required")
|
||||
return {}
|
||||
|
||||
deep_answer = state["deep_answer"]
|
||||
sub_qas = state["sub_qas"]
|
||||
sub_qa_list = []
|
||||
for sub_qa in sub_qas:
|
||||
if sub_qa["sub_answer_check"] == "yes":
|
||||
sub_qa_list.append(
|
||||
f' Question:\n {sub_qa["sub_question"]}\n --\n Answer:\n {sub_qa["sub_answer"]}\n -----'
|
||||
)
|
||||
|
||||
sub_qa_context = "\n".join(sub_qa_list)
|
||||
|
||||
print(f"Final Base Answer:\n{base_answer}")
|
||||
print("--------------------------------")
|
||||
print(f"Final Deep Answer:\n{deep_answer}")
|
||||
print("--------------------------------")
|
||||
print("Sub Questions and Answers:")
|
||||
print(sub_qa_context)
|
||||
|
||||
return {}
|
||||
@@ -1,78 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.main.states import MainState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import DEEP_DECOMPOSE_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_entity_term_extraction
|
||||
from danswer.agent_search.shared_graph_utils.utils import generate_log_message
|
||||
|
||||
|
||||
def decompose(state: MainState) -> dict[str, Any]:
|
||||
""" """
|
||||
|
||||
node_start_time = datetime.now()
|
||||
|
||||
question = state["original_question"]
|
||||
base_answer = state["base_answer"]
|
||||
|
||||
# get the entity term extraction dict and properly format it
|
||||
entity_term_extraction_dict = state["retrieved_entities_relationships"][
|
||||
"retrieved_entities_relationships"
|
||||
]
|
||||
|
||||
entity_term_extraction_str = format_entity_term_extraction(
|
||||
entity_term_extraction_dict
|
||||
)
|
||||
|
||||
initial_question_answers = state["initial_sub_qas"]
|
||||
|
||||
addressed_question_list = [
|
||||
x["sub_question"]
|
||||
for x in initial_question_answers
|
||||
if x["sub_answer_check"] == "yes"
|
||||
]
|
||||
failed_question_list = [
|
||||
x["sub_question"]
|
||||
for x in initial_question_answers
|
||||
if x["sub_answer_check"] == "no"
|
||||
]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=DEEP_DECOMPOSE_PROMPT.format(
|
||||
question=question,
|
||||
entity_term_extraction_str=entity_term_extraction_str,
|
||||
base_answer=base_answer,
|
||||
answered_sub_questions="\n - ".join(addressed_question_list),
|
||||
failed_sub_questions="\n - ".join(failed_question_list),
|
||||
),
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = state["fast_llm"]
|
||||
response = model.invoke(msg)
|
||||
|
||||
cleaned_response = re.sub(r"```json\n|\n```", "", response.pretty_repr())
|
||||
parsed_response = json.loads(cleaned_response)
|
||||
|
||||
sub_questions_dict = {}
|
||||
for sub_question_nr, sub_question_dict in enumerate(
|
||||
parsed_response["sub_questions"]
|
||||
):
|
||||
sub_question_dict["answered"] = False
|
||||
sub_question_dict["verified"] = False
|
||||
sub_questions_dict[sub_question_nr] = sub_question_dict
|
||||
|
||||
return {
|
||||
"decomposed_sub_questions_dict": sub_questions_dict,
|
||||
"log_messages": generate_log_message(
|
||||
message="deep - decompose",
|
||||
node_start_time=node_start_time,
|
||||
graph_start_time=state["graph_start_time"],
|
||||
),
|
||||
}
|
||||
@@ -1,40 +0,0 @@
|
||||
import json
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
|
||||
from danswer.agent_search.main.states import MainState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import ENTITY_TERM_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_docs
|
||||
|
||||
|
||||
def entity_term_extraction(state: MainState) -> dict[str, Any]:
|
||||
"""Extract entities and terms from the question and context"""
|
||||
|
||||
question = state["original_question"]
|
||||
docs = state["deduped_retrieval_docs"]
|
||||
|
||||
doc_context = format_docs(docs)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=ENTITY_TERM_PROMPT.format(question=question, context=doc_context),
|
||||
)
|
||||
]
|
||||
fast_llm = state["fast_llm"]
|
||||
# Grader
|
||||
llm_response_list = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
)
|
||||
)
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
|
||||
cleaned_response = re.sub(r"```json\n|\n```", "", llm_response)
|
||||
parsed_response = json.loads(cleaned_response)
|
||||
|
||||
return {
|
||||
"retrieved_entities_relationships": parsed_response,
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.main.states import MainState
|
||||
|
||||
|
||||
# aggregate sub questions and answers
|
||||
def sub_qa_level_aggregator(state: MainState) -> dict[str, Any]:
|
||||
sub_qas = state["sub_qas"]
|
||||
|
||||
dynamic_context_list = [
|
||||
"Below you will find useful information to answer the original question:"
|
||||
]
|
||||
checked_sub_qas = []
|
||||
|
||||
for core_answer_sub_qa in sub_qas:
|
||||
question = core_answer_sub_qa["sub_question"]
|
||||
answer = core_answer_sub_qa["sub_answer"]
|
||||
verified = core_answer_sub_qa["sub_answer_check"]
|
||||
|
||||
if verified == "yes":
|
||||
dynamic_context_list.append(
|
||||
f"Question:\n{question}\n\nAnswer:\n{answer}\n\n---\n\n"
|
||||
)
|
||||
checked_sub_qas.append({"sub_question": question, "sub_answer": answer})
|
||||
dynamic_context = "\n".join(dynamic_context_list)
|
||||
|
||||
return {
|
||||
"core_answer_dynamic_context": dynamic_context,
|
||||
"checked_sub_qas": checked_sub_qas,
|
||||
}
|
||||
@@ -1,19 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.agent_search.main.states import MainState
|
||||
|
||||
|
||||
def sub_qa_manager(state: MainState) -> dict[str, Any]:
|
||||
""" """
|
||||
|
||||
sub_questions_dict = state["decomposed_sub_questions_dict"]
|
||||
|
||||
sub_questions = {}
|
||||
|
||||
for sub_question_nr, sub_question_dict in sub_questions_dict.items():
|
||||
sub_questions[sub_question_nr] = sub_question_dict["sub_question"]
|
||||
|
||||
return {
|
||||
"sub_questions": sub_questions,
|
||||
"num_new_question_iterations": 0,
|
||||
}
|
||||
@@ -1,44 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
from langgraph.types import Send
|
||||
|
||||
from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import RetrieveInput
|
||||
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
|
||||
from danswer.agent_search.shared_graph_utils.prompts import REWRITE_PROMPT_MULTI
|
||||
from danswer.llm.interfaces import LLM
|
||||
|
||||
|
||||
def parallel_retrieval_edge(state: ExpandedRetrievalInput) -> list[Send | Hashable]:
|
||||
print(f"parallel_retrieval_edge state: {state.keys()}")
|
||||
|
||||
# This should be better...
|
||||
question = state.get("query_to_answer") or state["search_request"].query
|
||||
llm: LLM = state["fast_llm"]
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=REWRITE_PROMPT_MULTI.format(question=question),
|
||||
)
|
||||
]
|
||||
llm_response_list = list(
|
||||
llm.stream(
|
||||
prompt=msg,
|
||||
)
|
||||
)
|
||||
llm_response = merge_message_runs(llm_response_list, chunk_separator="")[0].content
|
||||
|
||||
print(f"llm_response: {llm_response}")
|
||||
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
|
||||
print(f"rewritten_queries: {rewritten_queries}")
|
||||
|
||||
return [
|
||||
Send(
|
||||
"doc_retrieval",
|
||||
RetrieveInput(query_to_retrieve=query, **state),
|
||||
)
|
||||
for query in rewritten_queries
|
||||
]
|
||||
@@ -1,88 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from danswer.agent_search.expanded_retrieval.edges import parallel_retrieval_edge
|
||||
from danswer.agent_search.expanded_retrieval.nodes.doc_reranking import doc_reranking
|
||||
from danswer.agent_search.expanded_retrieval.nodes.doc_retrieval import doc_retrieval
|
||||
from danswer.agent_search.expanded_retrieval.nodes.doc_verification import (
|
||||
doc_verification,
|
||||
)
|
||||
from danswer.agent_search.expanded_retrieval.nodes.verification_kickoff import (
|
||||
verification_kickoff,
|
||||
)
|
||||
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalInput
|
||||
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalOutput
|
||||
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
|
||||
|
||||
|
||||
def expanded_retrieval_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=ExpandedRetrievalState,
|
||||
input=ExpandedRetrievalInput,
|
||||
output=ExpandedRetrievalOutput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="doc_retrieval",
|
||||
action=doc_retrieval,
|
||||
)
|
||||
graph.add_node(
|
||||
node="verification_kickoff",
|
||||
action=verification_kickoff,
|
||||
)
|
||||
graph.add_node(
|
||||
node="doc_verification",
|
||||
action=doc_verification,
|
||||
)
|
||||
graph.add_node(
|
||||
node="doc_reranking",
|
||||
action=doc_reranking,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
|
||||
graph.add_conditional_edges(
|
||||
source=START,
|
||||
path=parallel_retrieval_edge,
|
||||
path_map=["doc_retrieval"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_retrieval",
|
||||
end_key="verification_kickoff",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_verification",
|
||||
end_key="doc_reranking",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="doc_reranking",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.context.search.models import SearchRequest
|
||||
|
||||
graph = expanded_retrieval_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="Who made Excel and what other products did they make?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
inputs = ExpandedRetrievalInput(
|
||||
search_request=search_request,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
db_session=db_session,
|
||||
query_to_answer="Who made Excel?",
|
||||
)
|
||||
for thing in compiled_graph.stream(inputs, debug=True):
|
||||
print(thing)
|
||||
@@ -1,11 +0,0 @@
|
||||
from danswer.agent_search.expanded_retrieval.states import DocRerankingOutput
|
||||
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
|
||||
|
||||
|
||||
def doc_reranking(state: ExpandedRetrievalState) -> DocRerankingOutput:
|
||||
print(f"doc_reranking state: {state.keys()}")
|
||||
|
||||
verified_documents = state["verified_documents"]
|
||||
reranked_documents = verified_documents
|
||||
|
||||
return DocRerankingOutput(reranked_documents=reranked_documents)
|
||||
@@ -1,47 +0,0 @@
|
||||
from danswer.agent_search.expanded_retrieval.states import DocRetrievalOutput
|
||||
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import SearchRequest
|
||||
from danswer.context.search.pipeline import SearchPipeline
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
|
||||
|
||||
class RetrieveInput(ExpandedRetrievalState):
|
||||
query_to_retrieve: str
|
||||
|
||||
|
||||
def doc_retrieval(state: RetrieveInput) -> DocRetrievalOutput:
|
||||
# def doc_retrieval(state: RetrieveInput) -> Command[Literal["doc_verification"]]:
|
||||
"""
|
||||
Retrieve documents
|
||||
|
||||
Args:
|
||||
state (dict): The current graph state
|
||||
|
||||
Returns:
|
||||
state (dict): New key added to state, documents, that contains retrieved documents
|
||||
"""
|
||||
print(f"doc_retrieval state: {state.keys()}")
|
||||
|
||||
state["query_to_retrieve"]
|
||||
|
||||
documents: list[InferenceSection] = []
|
||||
llm = state["primary_llm"]
|
||||
fast_llm = state["fast_llm"]
|
||||
# db_session = state["db_session"]
|
||||
query_to_retrieve = state["search_request"].query
|
||||
with get_session_context_manager() as db_session1:
|
||||
documents = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query=query_to_retrieve,
|
||||
),
|
||||
user=None,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
db_session=db_session1,
|
||||
).reranked_sections
|
||||
|
||||
print(f"retrieved documents: {len(documents)}")
|
||||
return DocRetrievalOutput(
|
||||
retrieved_documents=documents,
|
||||
)
|
||||
@@ -1,60 +0,0 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import merge_message_runs
|
||||
|
||||
from danswer.agent_search.expanded_retrieval.states import DocVerificationOutput
|
||||
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
|
||||
from danswer.agent_search.shared_graph_utils.models import BinaryDecision
|
||||
from danswer.agent_search.shared_graph_utils.prompts import VERIFIER_PROMPT
|
||||
from danswer.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class DocVerificationInput(ExpandedRetrievalState, total=True):
|
||||
doc_to_verify: InferenceSection
|
||||
|
||||
|
||||
def doc_verification(state: DocVerificationInput) -> DocVerificationOutput:
|
||||
"""
|
||||
Check whether the document is relevant for the original user question
|
||||
|
||||
Args:
|
||||
state (VerifierState): The current state
|
||||
|
||||
Returns:
|
||||
dict: ict: The updated state with the final decision
|
||||
"""
|
||||
|
||||
print(f"doc_verification state: {state.keys()}")
|
||||
|
||||
original_query = state["search_request"].query
|
||||
doc_to_verify = state["doc_to_verify"]
|
||||
document_content = doc_to_verify.combined_content
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=VERIFIER_PROMPT.format(
|
||||
question=original_query, document_content=document_content
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
fast_llm = state["fast_llm"]
|
||||
response = list(
|
||||
fast_llm.stream(
|
||||
prompt=msg,
|
||||
)
|
||||
)
|
||||
|
||||
response_string = merge_message_runs(response, chunk_separator="")[0].content
|
||||
# Convert string response to proper dictionary format
|
||||
decision_dict = {"decision": response_string.lower()}
|
||||
formatted_response = BinaryDecision.model_validate(decision_dict)
|
||||
|
||||
print(f"Verdict: {formatted_response.decision}")
|
||||
|
||||
verified_documents = []
|
||||
if formatted_response.decision == "yes":
|
||||
verified_documents.append(doc_to_verify)
|
||||
|
||||
return DocVerificationOutput(
|
||||
verified_documents=verified_documents,
|
||||
)
|
||||
@@ -1,27 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from langgraph.types import Command
|
||||
from langgraph.types import Send
|
||||
|
||||
from danswer.agent_search.expanded_retrieval.nodes.doc_verification import (
|
||||
DocVerificationInput,
|
||||
)
|
||||
from danswer.agent_search.expanded_retrieval.states import ExpandedRetrievalState
|
||||
|
||||
|
||||
def verification_kickoff(
|
||||
state: ExpandedRetrievalState,
|
||||
) -> Command[Literal["doc_verification"]]:
|
||||
print(f"verification_kickoff state: {state.keys()}")
|
||||
|
||||
documents = state["retrieved_documents"]
|
||||
return Command(
|
||||
update={},
|
||||
goto=[
|
||||
Send(
|
||||
node="doc_verification",
|
||||
arg=DocVerificationInput(doc_to_verify=doc, **state),
|
||||
)
|
||||
for doc in documents
|
||||
],
|
||||
)
|
||||
@@ -1,36 +0,0 @@
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from danswer.agent_search.core_state import PrimaryState
|
||||
from danswer.agent_search.shared_graph_utils.operators import dedup_inference_sections
|
||||
from danswer.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class DocRetrievalOutput(TypedDict, total=False):
|
||||
retrieved_documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
|
||||
|
||||
class DocVerificationOutput(TypedDict, total=False):
|
||||
verified_documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
|
||||
|
||||
class DocRerankingOutput(TypedDict, total=False):
|
||||
reranked_documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
|
||||
|
||||
class ExpandedRetrievalState(
|
||||
PrimaryState,
|
||||
DocRetrievalOutput,
|
||||
DocVerificationOutput,
|
||||
DocRerankingOutput,
|
||||
total=True,
|
||||
):
|
||||
query_to_answer: str
|
||||
|
||||
|
||||
class ExpandedRetrievalInput(PrimaryState, total=True):
|
||||
query_to_answer: str
|
||||
|
||||
|
||||
class ExpandedRetrievalOutput(TypedDict):
|
||||
reordered_documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
@@ -1,61 +0,0 @@
|
||||
from collections.abc import Hashable
|
||||
|
||||
from langgraph.types import Send
|
||||
|
||||
from danswer.agent_search.answer_query.states import AnswerQueryInput
|
||||
from danswer.agent_search.main.states import MainState
|
||||
|
||||
|
||||
def parallelize_decompozed_answer_queries(state: MainState) -> list[Send | Hashable]:
|
||||
return [
|
||||
Send(
|
||||
"answer_query",
|
||||
AnswerQueryInput(
|
||||
**state,
|
||||
query_to_answer=query,
|
||||
),
|
||||
)
|
||||
for query in state["initial_decomp_queries"]
|
||||
]
|
||||
|
||||
|
||||
# def continue_to_answer_sub_questions(state: QAState) -> Union[Hashable, list[Hashable]]:
|
||||
# # Routes re-written queries to the (parallel) retrieval steps
|
||||
# # Notice the 'Send()' API that takes care of the parallelization
|
||||
# return [
|
||||
# Send(
|
||||
# "sub_answers_graph",
|
||||
# ResearchQAState(
|
||||
# sub_question=sub_question["sub_question_str"],
|
||||
# sub_question_nr=sub_question["sub_question_nr"],
|
||||
# graph_start_time=state["graph_start_time"],
|
||||
# primary_llm=state["primary_llm"],
|
||||
# fast_llm=state["fast_llm"],
|
||||
# ),
|
||||
# )
|
||||
# for sub_question in state["sub_questions"]
|
||||
# ]
|
||||
|
||||
|
||||
# def continue_to_deep_answer(state: QAState) -> Union[Hashable, list[Hashable]]:
|
||||
# print("---GO TO DEEP ANSWER OR END---")
|
||||
|
||||
# base_answer = state["base_answer"]
|
||||
|
||||
# question = state["original_question"]
|
||||
|
||||
# BASE_CHECK_MESSAGE = [
|
||||
# HumanMessage(
|
||||
# content=BASE_CHECK_PROMPT.format(question=question, base_answer=base_answer)
|
||||
# )
|
||||
# ]
|
||||
|
||||
# model = state["fast_llm"]
|
||||
# response = model.invoke(BASE_CHECK_MESSAGE)
|
||||
|
||||
# print(f"CAN WE CONTINUE W/O GENERATING A DEEP ANSWER? - {response.pretty_repr()}")
|
||||
|
||||
# if response.pretty_repr() == "no":
|
||||
# return "decompose"
|
||||
# else:
|
||||
# return "end"
|
||||
@@ -1,98 +0,0 @@
|
||||
from langgraph.graph import END
|
||||
from langgraph.graph import START
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from danswer.agent_search.answer_query.graph_builder import answer_query_graph_builder
|
||||
from danswer.agent_search.expanded_retrieval.graph_builder import (
|
||||
expanded_retrieval_graph_builder,
|
||||
)
|
||||
from danswer.agent_search.main.edges import parallelize_decompozed_answer_queries
|
||||
from danswer.agent_search.main.nodes.base_decomp import main_decomp_base
|
||||
from danswer.agent_search.main.nodes.generate_initial_answer import (
|
||||
generate_initial_answer,
|
||||
)
|
||||
from danswer.agent_search.main.states import MainInput
|
||||
from danswer.agent_search.main.states import MainState
|
||||
|
||||
|
||||
def main_graph_builder() -> StateGraph:
|
||||
graph = StateGraph(
|
||||
state_schema=MainState,
|
||||
input=MainInput,
|
||||
)
|
||||
|
||||
### Add nodes ###
|
||||
|
||||
graph.add_node(
|
||||
node="base_decomp",
|
||||
action=main_decomp_base,
|
||||
)
|
||||
answer_query_subgraph = answer_query_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="answer_query",
|
||||
action=answer_query_subgraph,
|
||||
)
|
||||
expanded_retrieval_subgraph = expanded_retrieval_graph_builder().compile()
|
||||
graph.add_node(
|
||||
node="expanded_retrieval",
|
||||
action=expanded_retrieval_subgraph,
|
||||
)
|
||||
graph.add_node(
|
||||
node="generate_initial_answer",
|
||||
action=generate_initial_answer,
|
||||
)
|
||||
|
||||
### Add edges ###
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="expanded_retrieval",
|
||||
)
|
||||
|
||||
graph.add_edge(
|
||||
start_key=START,
|
||||
end_key="base_decomp",
|
||||
)
|
||||
graph.add_conditional_edges(
|
||||
source="base_decomp",
|
||||
path=parallelize_decompozed_answer_queries,
|
||||
path_map=["answer_query"],
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key=["answer_query", "expanded_retrieval"],
|
||||
end_key="generate_initial_answer",
|
||||
)
|
||||
graph.add_edge(
|
||||
start_key="generate_initial_answer",
|
||||
end_key=END,
|
||||
)
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.context.search.models import SearchRequest
|
||||
|
||||
graph = main_graph_builder()
|
||||
compiled_graph = graph.compile()
|
||||
primary_llm, fast_llm = get_default_llms()
|
||||
search_request = SearchRequest(
|
||||
query="If i am familiar with the function that I need, how can I type it into a cell?",
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
inputs = MainInput(
|
||||
search_request=search_request,
|
||||
primary_llm=primary_llm,
|
||||
fast_llm=fast_llm,
|
||||
db_session=db_session,
|
||||
)
|
||||
for thing in compiled_graph.stream(
|
||||
input=inputs,
|
||||
# stream_mode="debug",
|
||||
# debug=True,
|
||||
subgraphs=True,
|
||||
):
|
||||
# print(thing)
|
||||
print()
|
||||
print()
|
||||
@@ -1,31 +0,0 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.main.states import BaseDecompOutput
|
||||
from danswer.agent_search.main.states import MainState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import INITIAL_DECOMPOSITION_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import clean_and_parse_list_string
|
||||
|
||||
|
||||
def main_decomp_base(state: MainState) -> BaseDecompOutput:
|
||||
question = state["search_request"].query
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=INITIAL_DECOMPOSITION_PROMPT.format(question=question),
|
||||
)
|
||||
]
|
||||
|
||||
# Get the rewritten queries in a defined format
|
||||
model = state["fast_llm"]
|
||||
response = model.invoke(msg)
|
||||
|
||||
content = response.pretty_repr()
|
||||
list_of_subquestions = clean_and_parse_list_string(content)
|
||||
|
||||
decomp_list: list[str] = [
|
||||
sub_question["sub_question"].strip() for sub_question in list_of_subquestions
|
||||
]
|
||||
|
||||
return BaseDecompOutput(
|
||||
initial_decomp_queries=decomp_list,
|
||||
)
|
||||
@@ -1,53 +0,0 @@
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from danswer.agent_search.main.states import InitialAnswerOutput
|
||||
from danswer.agent_search.main.states import MainState
|
||||
from danswer.agent_search.shared_graph_utils.prompts import INITIAL_RAG_PROMPT
|
||||
from danswer.agent_search.shared_graph_utils.utils import format_docs
|
||||
|
||||
|
||||
def generate_initial_answer(state: MainState) -> InitialAnswerOutput:
|
||||
print("---GENERATE INITIAL---")
|
||||
|
||||
question = state["search_request"].query
|
||||
docs = state["documents"]
|
||||
|
||||
decomp_answer_results = state["decomp_answer_results"]
|
||||
|
||||
good_qa_list: list[str] = []
|
||||
|
||||
_SUB_QUESTION_ANSWER_TEMPLATE = """
|
||||
Sub-Question:\n - {sub_question}\n --\nAnswer:\n - {sub_answer}\n\n
|
||||
"""
|
||||
for decomp_answer_result in decomp_answer_results:
|
||||
if (
|
||||
decomp_answer_result.quality.lower() == "yes"
|
||||
and len(decomp_answer_result.answer) > 0
|
||||
and decomp_answer_result.answer != "I don't know"
|
||||
):
|
||||
good_qa_list.append(
|
||||
_SUB_QUESTION_ANSWER_TEMPLATE.format(
|
||||
sub_question=decomp_answer_result.query,
|
||||
sub_answer=decomp_answer_result.answer,
|
||||
)
|
||||
)
|
||||
|
||||
sub_question_answer_str = "\n\n------\n\n".join(good_qa_list)
|
||||
|
||||
msg = [
|
||||
HumanMessage(
|
||||
content=INITIAL_RAG_PROMPT.format(
|
||||
question=question,
|
||||
context=format_docs(docs),
|
||||
answered_sub_questions=sub_question_answer_str,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
# Grader
|
||||
model = state["fast_llm"]
|
||||
response = model.invoke(msg)
|
||||
answer = response.pretty_repr()
|
||||
|
||||
print(answer)
|
||||
return InitialAnswerOutput(initial_answer=answer)
|
||||
@@ -1,37 +0,0 @@
|
||||
from operator import add
|
||||
from typing import Annotated
|
||||
from typing import TypedDict
|
||||
|
||||
from danswer.agent_search.answer_query.states import SearchAnswerResults
|
||||
from danswer.agent_search.core_state import PrimaryState
|
||||
from danswer.agent_search.shared_graph_utils.operators import dedup_inference_sections
|
||||
from danswer.context.search.models import InferenceSection
|
||||
|
||||
|
||||
class BaseDecompOutput(TypedDict, total=False):
|
||||
initial_decomp_queries: list[str]
|
||||
|
||||
|
||||
class InitialAnswerOutput(TypedDict, total=False):
|
||||
initial_answer: str
|
||||
|
||||
|
||||
class MainState(
|
||||
PrimaryState,
|
||||
BaseDecompOutput,
|
||||
InitialAnswerOutput,
|
||||
total=True,
|
||||
):
|
||||
documents: Annotated[list[InferenceSection], dedup_inference_sections]
|
||||
decomp_answer_results: Annotated[list[SearchAnswerResults], add]
|
||||
|
||||
|
||||
class MainInput(PrimaryState, total=True):
|
||||
pass
|
||||
|
||||
|
||||
class MainOutput(TypedDict):
|
||||
"""
|
||||
This is not used because defining the output only matters for filtering the output of
|
||||
a .invoke() call but we are streaming so we just yield the entire state.
|
||||
"""
|
||||
@@ -1,27 +0,0 @@
|
||||
from danswer.agent_search.primary_graph.graph_builder import build_core_graph
|
||||
from danswer.llm.answering.answer import AnswerStream
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.tool import Tool
|
||||
|
||||
|
||||
def run_graph(
|
||||
query: str,
|
||||
llm: LLM,
|
||||
tools: list[Tool],
|
||||
) -> AnswerStream:
|
||||
graph = build_core_graph()
|
||||
|
||||
inputs = {
|
||||
"original_query": query,
|
||||
"messages": [],
|
||||
"tools": tools,
|
||||
"llm": llm,
|
||||
}
|
||||
compiled_graph = graph.compile()
|
||||
output = compiled_graph.invoke(input=inputs)
|
||||
yield from output
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pass
|
||||
# run_graph("What is the capital of France?", llm, [])
|
||||
@@ -1,12 +0,0 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Pydantic models for structured outputs
|
||||
class RewrittenQueries(BaseModel):
|
||||
rewritten_queries: list[str]
|
||||
|
||||
|
||||
class BinaryDecision(BaseModel):
|
||||
decision: Literal["yes", "no"]
|
||||
@@ -1,9 +0,0 @@
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.llm.answering.prune_and_merge import _merge_sections
|
||||
|
||||
|
||||
def dedup_inference_sections(
|
||||
list1: list[InferenceSection], list2: list[InferenceSection]
|
||||
) -> list[InferenceSection]:
|
||||
deduped = _merge_sections(list1 + list2)
|
||||
return deduped
|
||||
@@ -1,427 +0,0 @@
|
||||
REWRITE_PROMPT_MULTI_ORIGINAL = """ \n
|
||||
Please convert an initial user question into a 2-3 more appropriate short and pointed search queries for retrievel from a
|
||||
document store. Particularly, try to think about resolving ambiguities and make the search queries more specific,
|
||||
enabling the system to search more broadly.
|
||||
Also, try to make the search queries not redundant, i.e. not too similar! \n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Formulate the queries separated by '--' (Do not say 'Query 1: ...', just write the querytext): """
|
||||
|
||||
REWRITE_PROMPT_MULTI = """ \n
|
||||
Please create a list of 2-3 sample documents that could answer an original question. Each document
|
||||
should be about as long as the original question. \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Formulate the sample documents separated by '--' (Do not say 'Document 1: ...', just write the text): """
|
||||
|
||||
BASE_RAG_PROMPT = """ \n
|
||||
You are an assistant for question-answering tasks. Use the context provided below - and only the
|
||||
provided context - to answer the question. If you don't know the answer or if the provided context is
|
||||
empty, just say "I don't know". Do not use your internal knowledge!
|
||||
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
Use three sentences maximum and keep the answer concise.
|
||||
answer concise.\nQuestion:\n {question} \nContext:\n {context} \n\n
|
||||
\n\n
|
||||
Answer:"""
|
||||
|
||||
BASE_CHECK_PROMPT = """ \n
|
||||
Please check whether 1) the suggested answer seems to fully address the original question AND 2)the
|
||||
original question requests a simple, factual answer, and there are no ambiguities, judgements,
|
||||
aggregations, or any other complications that may require extra context. (I.e., if the question is
|
||||
somewhat addressed, but the answer would benefit from more context, then answer with 'no'.)
|
||||
|
||||
Please only answer with 'yes' or 'no' \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the proposed answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
Please answer with yes or no:"""
|
||||
|
||||
VERIFIER_PROMPT = """ \n
|
||||
Please check whether the document seems to be relevant for the answer of the question. Please
|
||||
only answer with 'yes' or 'no' \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
Here is the document text:
|
||||
\n ------- \n
|
||||
{document_content}
|
||||
\n ------- \n
|
||||
Please answer with yes or no:"""
|
||||
|
||||
INITIAL_DECOMPOSITION_PROMPT_BASIC = """ \n
|
||||
Please decompose an initial user question into not more than 4 appropriate sub-questions that help to
|
||||
answer the original question. The purpose for this decomposition is to isolate individulal entities
|
||||
(i.e., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
|
||||
for company B'), split ambiguous terms (i.e., 'what is our success with company A' -> 'what are our
|
||||
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
|
||||
for us'), etc. Each sub-question should be realistically be answerable by a good RAG system. \n
|
||||
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Please formulate your answer as a list of subquestions:
|
||||
|
||||
Answer:
|
||||
"""
|
||||
|
||||
REWRITE_PROMPT_SINGLE = """ \n
|
||||
Please convert an initial user question into a more appropriate search query for retrievel from a
|
||||
document store. \n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Formulate the query: """
|
||||
|
||||
MODIFIED_RAG_PROMPT = """You are an assistant for question-answering tasks. Use the context provided below
|
||||
- and only this context - to answer the question. If you don't know the answer, just say "I don't know".
|
||||
Use three sentences maximum and keep the answer concise.
|
||||
Pay also particular attention to the sub-questions and their answers, at least it may enrich the answer.
|
||||
Again, only use the provided context and do not use your internal knowledge! If you cannot answer the
|
||||
question based on the context, say "I don't know". It is a matter of life and death that you do NOT
|
||||
use your internal knowledge, just the provided information!
|
||||
|
||||
\nQuestion: {question}
|
||||
\nContext: {combined_context} \n
|
||||
|
||||
Answer:"""
|
||||
|
||||
ORIG_DEEP_DECOMPOSE_PROMPT = """ \n
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
|
||||
good enough. Also, some sub-questions had been answered and this information has been used to provide
|
||||
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
|
||||
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
|
||||
you have an idea of how the avaiolable data looks like.
|
||||
|
||||
Your role is to generate 3-5 new sub-questions that would help to answer the initial question,
|
||||
considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered
|
||||
4) The sub-questions that were suggested but not answered
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question, but in a way that does
|
||||
not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
resolve ambiguities, or address shortcoming of the initial answer
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
|
||||
generate a list of dictionaries with the following format:
|
||||
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
|
||||
sub-question using as a search phrase for the document store>}}, ...]
|
||||
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
\n ------- \n
|
||||
{answered_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
\n ------- \n
|
||||
{failed_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Again, please find questions that are NOT overlapping too much with the already answered
|
||||
sub-questions or those that already were suggested and failed.
|
||||
In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Please think through it step by step and then generate the list of json dictionaries with the following
|
||||
format:
|
||||
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"explanation": <explanation>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
DEEP_DECOMPOSE_PROMPT = """ \n
|
||||
An initial user question needs to be answered. An initial answer has been provided but it wasn't quite
|
||||
good enough. Also, some sub-questions had been answered and this information has been used to provide
|
||||
the initial answer. Some other subquestions may have been suggested based on little knowledge, but they
|
||||
were not directly answerable. Also, some entities, relationships and terms are givenm to you so that
|
||||
you have an idea of how the avaiolable data looks like.
|
||||
|
||||
Your role is to generate 4-6 new sub-questions that would help to answer the initial question,
|
||||
considering:
|
||||
|
||||
1) The initial question
|
||||
2) The initial answer that was found to be unsatisfactory
|
||||
3) The sub-questions that were answered
|
||||
4) The sub-questions that were suggested but not answered
|
||||
5) The entities, relationships and terms that were extracted from the context
|
||||
|
||||
The individual questions should be answerable by a good RAG system.
|
||||
So a good idea would be to use the sub-questions to resolve ambiguities and/or to separate the
|
||||
question for different entities that may be involved in the original question, but in a way that does
|
||||
not duplicate questions that were already tried.
|
||||
|
||||
Additional Guidelines:
|
||||
- The sub-questions should be specific to the question and provide richer context for the question,
|
||||
resolve ambiguities, or address shortcoming of the initial answer
|
||||
- Each sub-question - when answered - should be relevant for the answer to the original question
|
||||
- The sub-questions should be free from comparisions, ambiguities,judgements, aggregations, or any
|
||||
other complications that may require extra context.
|
||||
- The sub-questions MUST have the full context of the original question so that it can be executed by
|
||||
a RAG system independently without the original question available
|
||||
(Example:
|
||||
- initial question: "What is the capital of France?"
|
||||
- bad sub-question: "What is the name of the river there?"
|
||||
- good sub-question: "What is the name of the river that flows through Paris?"
|
||||
- For each sub-question, please also provide a search term that can be used to retrieve relevant
|
||||
documents from a document store.
|
||||
\n\n
|
||||
Here is the initial question:
|
||||
\n ------- \n
|
||||
{question}
|
||||
\n ------- \n
|
||||
|
||||
Here is the initial sub-optimal answer:
|
||||
\n ------- \n
|
||||
{base_answer}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were answered:
|
||||
\n ------- \n
|
||||
{answered_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
Here are the sub-questions that were suggested but not answered:
|
||||
\n ------- \n
|
||||
{failed_sub_questions}
|
||||
\n ------- \n
|
||||
|
||||
And here are the entities, relationships and terms extracted from the context:
|
||||
\n ------- \n
|
||||
{entity_term_extraction_str}
|
||||
\n ------- \n
|
||||
|
||||
Please generate the list of good, fully contextualized sub-questions that would help to address the
|
||||
main question. Again, please find questions that are NOT overlapping too much with the already answered
|
||||
sub-questions or those that already were suggested and failed.
|
||||
In other words - what can we try in addition to what has been tried so far?
|
||||
|
||||
Generate the list of json dictionaries with the following format:
|
||||
|
||||
{{"sub_questions": [{{"sub_question": <sub-question>,
|
||||
"search_term": <rewrite the sub-question using as a search phrase for the document store>}},
|
||||
...]}} """
|
||||
|
||||
DECOMPOSE_PROMPT = """ \n
For an initial user question, please generate 5-10 individual sub-questions whose answers would help
\n to answer the initial question. The individual questions should be answerable by a good RAG system.
So a good idea would be to \n use the sub-questions to resolve ambiguities and/or to separate the
question for different entities that may be involved in the original question.

In order to arrive at meaningful sub-questions, please also consider the context retrieved from the
document store, expressed as entities, relationships and terms. You can also think about the types
mentioned in brackets.

Guidelines:
- The sub-questions should be specific to the question and provide richer context for the question,
and/or resolve ambiguities
- Each sub-question - when answered - should be relevant for the answer to the original question
- The sub-questions should be free from comparisons, ambiguities, judgements, aggregations, or any
other complications that may require extra context.
- The sub-questions MUST have the full context of the original question so that they can be executed by
a RAG system independently without the original question available
(Example:
- initial question: "What is the capital of France?"
- bad sub-question: "What is the name of the river there?"
- good sub-question: "What is the name of the river that flows through Paris?"
- For each sub-question, please provide a short explanation for why it is a good sub-question. So
generate a list of dictionaries with the following format:
[{{"sub_question": <sub-question>, "explanation": <explanation>, "search_term": <rewrite the
sub-question to use as a search phrase for the document store>}}, ...]

\n\n
Here is the initial question:
\n ------- \n
{question}
\n ------- \n

And here are the entities, relationships and terms extracted from the context:
\n ------- \n
{entity_term_extraction_str}
\n ------- \n

Please generate the list of good, fully contextualized sub-questions that would help to address the
main question. Don't be too specific unless the original question is specific.
Please think through it step by step and then generate the list of json dictionaries with the following
format:
{{"sub_questions": [{{"sub_question": <sub-question>,
"explanation": <explanation>,
"search_term": <rewrite the sub-question to use as a search phrase for the document store>}},
...]}} """

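As a rough illustration of how a decomposition prompt like the one above could be driven, the sketch below fills in the placeholders with str.format and parses the model's JSON reply. Only the prompt and its placeholder names come from the code above; decompose_question and call_llm are hypothetical stand-ins, not names from the project.

# Minimal sketch (not part of the diff): format the prompt and parse the reply.
import json
import re


def decompose_question(question: str, entity_term_extraction_str: str) -> list[dict]:
    prompt = DECOMPOSE_PROMPT.format(
        question=question,
        entity_term_extraction_str=entity_term_extraction_str,
    )
    raw = call_llm(prompt)  # hypothetical LLM call returning the model's text reply
    # Strip optional ```json fences before parsing, mirroring the cleanup
    # helpers that appear later in this comparison.
    cleaned = re.sub(r"```json\n|\n```", "", raw).strip()
    return json.loads(cleaned)["sub_questions"]
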
#### Consolidations

COMBINED_CONTEXT = """-------
Below you will find useful information to answer the original question. First, you see a number of
sub-questions with their answers. This information should be considered to be more focused and
somewhat more specific to the original question, as it tries to contextualize facts.
After that you will see the documents that were considered to be relevant to answer the original question.

Here are the sub-questions and their answers:
\n\n {deep_answer_context} \n\n
\n\n Here are the documents that were considered to be relevant to answer the original question:
\n\n {formated_docs} \n\n
----------------
"""

SUB_QUESTION_EXPLANATION_RANKER_PROMPT = """-------
Below you will find a question that we ultimately want to answer (the original question) and a list of
motivations in arbitrary order for generated sub-questions that are supposed to help us answer the
original question. The motivations are formatted as <motivation number>: <motivation explanation>.
(Again, the numbering is arbitrary and does not necessarily mean that 1 is the most relevant
motivation and 2 is less relevant.)

Please rank the motivations in order of relevance for answering the original question. Also, try to
ensure that the top questions do not duplicate too much, i.e. that they are not too similar.
Ultimately, create a list of the motivation numbers in which the numbers of the most relevant
motivations come first.

Here is the original question:
\n\n {original_question} \n\n
\n\n Here is the list of sub-question motivations:
\n\n {sub_question_explanations} \n\n
----------------

Please think step by step and then generate the ranked list of motivations.

Please format your answer as a json object in the following format:
{{"reasonning": <explain your reasoning for the ranking>,
"ranked_motivations": <ranked list of motivation numbers>}}
"""


INITIAL_DECOMPOSITION_PROMPT = """ \n
Please decompose an initial user question into 2 or 3 appropriate sub-questions that help to
answer the original question. The purpose for this decomposition is to isolate individual entities
(e.g., 'compare sales of company A and company B' -> 'what are sales for company A' + 'what are sales
for company B'), split ambiguous terms (e.g., 'what is our success with company A' -> 'what are our
sales with company A' + 'what is our market share with company A' + 'is company A a reference customer
for us'), etc. Each sub-question should realistically be answerable by a good RAG system. \n

For each sub-question, please also create one search term that can be used to retrieve relevant
documents from a document store.

Here is the initial question:
\n ------- \n
{question}
\n ------- \n

Please formulate your answer as a list of json objects with the following format:

[{{"sub_question": <sub-question>, "search_term": <search term>}}, ...]

Answer:
"""

INITIAL_RAG_PROMPT = """ \n
You are an assistant for question-answering tasks. Use the information provided below - and only the
provided information - to answer the provided question.

The information provided below consists of:
1) a number of answered sub-questions - these are very important(!) and definitely should be
considered when answering the question.
2) a number of documents that were also deemed relevant for the question.

If you don't know the answer or if the provided information is empty or insufficient, just say
"I don't know". Do not use your internal knowledge!

Again, only use the provided information and do not use your internal knowledge! It is a matter of life
and death that you do NOT use your internal knowledge, just the provided information!

Try to keep your answer concise.

And here is the question and the provided information:
\n
\nQuestion:\n {question}

\nAnswered Sub-questions:\n {answered_sub_questions}

\nContext:\n {context} \n\n
\n\n

Answer:"""

ENTITY_TERM_PROMPT = """ \n
Based on the original question and the context retrieved from a dataset, please generate a list of
entities (e.g. companies, organizations, industries, products, locations, etc.), terms and concepts
(e.g. sales, revenue, etc.) that are relevant for the question, plus their relations to each other.

\n\n
Here is the original question:
\n ------- \n
{question}
\n ------- \n
And here is the context retrieved:
\n ------- \n
{context}
\n ------- \n

Please format your answer as a json object in the following format:

{{"retrieved_entities_relationships": {{
    "entities": [{{
        "entity_name": <assign a name for the entity>,
        "entity_type": <specify a short type name for the entity, such as 'company', 'location',...>
    }}],
    "relationships": [{{
        "name": <assign a name for the relationship>,
        "type": <specify a short type name for the relationship, such as 'sales_to', 'is_location_of',...>,
        "entities": [<related entity name 1>, <related entity name 2>]
    }}],
    "terms": [{{
        "term_name": <assign a name for the term>,
        "term_type": <specify a short type name for the term, such as 'revenue', 'market_share',...>,
        "similar_to": <list terms that are similar to this term>
    }}]
}}
}}
"""

@@ -1,101 +0,0 @@
import ast
import json
import re
from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from typing import Any

from danswer.context.search.models import InferenceSection


def normalize_whitespace(text: str) -> str:
    """Normalize whitespace in text to single spaces and strip leading/trailing whitespace."""
    return re.sub(r"\s+", " ", text.strip())


# Post-processing
def format_docs(docs: Sequence[InferenceSection]) -> str:
    return "\n\n".join(doc.combined_content for doc in docs)


def clean_and_parse_list_string(json_string: str) -> list[dict]:
    # Remove any prefixes/labels before the actual JSON content
    json_string = re.sub(r"^.*?(?=\[)", "", json_string, flags=re.DOTALL)

    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())

    # Try parsing with json.loads first, fall back to ast.literal_eval
    try:
        return json.loads(cleaned_string)
    except json.JSONDecodeError:
        try:
            return ast.literal_eval(cleaned_string)
        except (ValueError, SyntaxError) as e:
            raise ValueError(f"Failed to parse JSON string: {cleaned_string}") from e


def clean_and_parse_json_string(json_string: str) -> dict[str, Any]:
    # Remove markdown code block markers and any newline prefixes
    cleaned_string = re.sub(r"```json\n|\n```", "", json_string)
    cleaned_string = cleaned_string.replace("\\n", " ").replace("\n", " ")
    cleaned_string = " ".join(cleaned_string.split())
    # Parse the cleaned string into a Python dictionary
    return json.loads(cleaned_string)


def format_entity_term_extraction(entity_term_extraction_dict: dict[str, Any]) -> str:
    entities = entity_term_extraction_dict["entities"]
    terms = entity_term_extraction_dict["terms"]
    relationships = entity_term_extraction_dict["relationships"]

    entity_strs = ["\nEntities:\n"]
    for entity in entities:
        entity_str = f"{entity['entity_name']} ({entity['entity_type']})"
        entity_strs.append(entity_str)

    entity_str = "\n - ".join(entity_strs)

    relationship_strs = ["\n\nRelationships:\n"]
    for relationship in relationships:
        relationship_str = f"{relationship['name']} ({relationship['type']}): {relationship['entities']}"
        relationship_strs.append(relationship_str)

    relationship_str = "\n - ".join(relationship_strs)

    term_strs = ["\n\nTerms:\n"]
    for term in terms:
        term_str = f"{term['term_name']} ({term['term_type']}): similar to {term['similar_to']}"
        term_strs.append(term_str)

    term_str = "\n - ".join(term_strs)

    return "\n".join(entity_strs + relationship_strs + term_strs)


def _format_time_delta(time: timedelta) -> str:
    seconds_from_start = f"{((time).seconds):03d}"
    microseconds_from_start = f"{((time).microseconds):06d}"
    return f"{seconds_from_start}.{microseconds_from_start}"


def generate_log_message(
    message: str,
    node_start_time: datetime,
    graph_start_time: datetime | None = None,
) -> str:
    current_time = datetime.now()

    if graph_start_time is not None:
        graph_time_str = _format_time_delta(current_time - graph_start_time)
    else:
        graph_time_str = "N/A"

    node_time_str = _format_time_delta(current_time - node_start_time)

    return f"{graph_time_str} ({node_time_str} s): {message}"

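For orientation, here is a small hypothetical usage of the helpers in the removed module above; the sample dictionary only mirrors the keys that format_entity_term_extraction reads and is not taken from the diff.

# Hypothetical example: feeding a minimal entity/term dict through
# format_entity_term_extraction from the module above.
sample = {
    "entities": [{"entity_name": "Company A", "entity_type": "company"}],
    "relationships": [
        {"name": "sells_to", "type": "sales_to", "entities": ["Company A", "Company B"]}
    ],
    "terms": [{"term_name": "revenue", "term_type": "revenue", "similar_to": ["sales"]}],
}

print(format_entity_term_extraction(sample))
# Prints an "Entities / Relationships / Terms" text block, suitable for the
# {entity_term_extraction_str} placeholder used by the prompts above.
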
@@ -87,7 +87,6 @@ from danswer.db.models import AccessToken
|
||||
from danswer.db.models import OAuthAccount
|
||||
from danswer.db.models import User
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
from danswer.utils.telemetry import RecordType
|
||||
@@ -100,6 +99,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
def is_user_admin(user: User | None) -> bool:
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
return True
|
||||
|
||||
@@ -11,7 +11,6 @@ from celery.exceptions import WorkerShutdown
|
||||
from celery.states import READY_STATES
|
||||
from celery.utils.log import get_task_logger
|
||||
from celery.worker import strategy # type: ignore
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sentry_sdk.integrations.celery import CeleryIntegration
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -333,16 +332,16 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
return
|
||||
|
||||
logger.info("Releasing primary worker lock.")
|
||||
lock: RedisLock = sender.primary_worker_lock
|
||||
lock = sender.primary_worker_lock
|
||||
try:
|
||||
if lock.owned():
|
||||
try:
|
||||
lock.release()
|
||||
sender.primary_worker_lock = None
|
||||
except Exception:
|
||||
logger.exception("Failed to release primary worker lock")
|
||||
except Exception:
|
||||
logger.exception("Failed to check if primary worker lock is owned")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to release primary worker lock: {e}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to check if primary worker lock is owned: {e}")
|
||||
|
||||
|
||||
def on_setup_logging(
|
||||
|
||||
@@ -11,7 +11,6 @@ from celery.signals import celeryd_init
|
||||
from celery.signals import worker_init
|
||||
from celery.signals import worker_ready
|
||||
from celery.signals import worker_shutdown
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
import danswer.background.celery.apps.app_base as app_base
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -39,6 +38,7 @@ from danswer.redis.redis_usergroup import RedisUserGroup
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
celery_app = Celery(__name__)
|
||||
@@ -116,13 +116,9 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
# it is planned to use this lock to enforce singleton behavior on the primary
|
||||
# worker, since the primary worker does redis cleanup on startup, but this isn't
|
||||
# implemented yet.
|
||||
|
||||
# set thread_local=False since we don't control what thread the periodic task might
|
||||
# reacquire the lock with
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRIMARY_WORKER,
|
||||
timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
logger.info("Primary worker lock: Acquire starting.")
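For context on the thread_local=False comments in this hunk, the following is a hedged sketch of the redis-py lock pattern being relied on; the key name and timeout below are placeholders rather than values from the project.

# Sketch of the redis-py lock pattern referenced above. With
# thread_local=False the lock token lives on the lock object instead of in
# thread-local storage, so another thread (e.g. a periodic task) can extend
# or release the same lock instance.
import redis

r = redis.Redis()
lock = r.lock("example:primary_worker", timeout=120, thread_local=False)  # placeholder name/timeout

if lock.acquire(blocking=False):
    try:
        ...  # primary-worker-only work
    finally:
        if lock.owned():
            lock.release()
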
|
||||
@@ -231,7 +227,7 @@ class HubPeriodicTask(bootsteps.StartStopStep):
|
||||
if not hasattr(worker, "primary_worker_lock"):
|
||||
return
|
||||
|
||||
lock: RedisLock = worker.primary_worker_lock
|
||||
lock = worker.primary_worker_lock
|
||||
|
||||
r = get_redis_client(tenant_id=None)
|
||||
|
||||
|
||||
@@ -2,55 +2,54 @@ from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
tasks_to_schedule = [
|
||||
{
|
||||
"name": "check-for-vespa-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
"task": "check_for_vespa_sync_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-connector-deletion",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"task": "check_for_connector_deletion_task",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-indexing",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
"task": "check_for_indexing",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-prune",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
"task": "check_for_pruning",
|
||||
"schedule": timedelta(seconds=15),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "kombu-message-cleanup",
|
||||
"task": DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
"task": "kombu_message_cleanup_task",
|
||||
"schedule": timedelta(seconds=3600),
|
||||
"options": {"priority": DanswerCeleryPriority.LOWEST},
|
||||
},
|
||||
{
|
||||
"name": "monitor-vespa-sync",
|
||||
"task": DanswerCeleryTask.MONITOR_VESPA_SYNC,
|
||||
"task": "monitor_vespa_sync",
|
||||
"schedule": timedelta(seconds=5),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-doc-permissions-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
"task": "check_for_doc_permissions_sync",
|
||||
"schedule": timedelta(seconds=30),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
|
||||
{
|
||||
"name": "check-for-external-group-sync",
|
||||
"task": DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
"task": "check_for_external_group_sync",
|
||||
"schedule": timedelta(seconds=20),
|
||||
"options": {"priority": DanswerCeleryPriority.HIGH},
|
||||
},
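The entries above only declare the schedule; they presumably get registered with Celery beat elsewhere in the codebase. As a hedged sketch, a tasks_to_schedule-style list maps directly onto Celery's standard beat_schedule configuration (the app name here is illustrative, not taken from the diff):

# Illustrative wiring of a tasks_to_schedule-style list into Celery beat.
from celery import Celery

app = Celery("example")  # hypothetical app instance

app.conf.beat_schedule = {
    entry["name"]: {
        "task": entry["task"],
        "schedule": entry["schedule"],
        "options": entry.get("options", {}),
    }
    for entry in tasks_to_schedule
}
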
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pairs
|
||||
@@ -29,7 +28,7 @@ class TaskDependencyError(RuntimeError):
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
name="check_for_connector_deletion_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
|
||||
@@ -18,11 +18,9 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import upsert_document_by_connector_credential_pair
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.enums import AccessType
|
||||
from danswer.db.enums import ConnectorCredentialPairStatus
|
||||
@@ -84,7 +82,7 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC,
|
||||
name="check_for_doc_permissions_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -166,7 +164,7 @@ def try_creating_permissions_sync_task(
|
||||
custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
"connector_permission_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -193,7 +191,7 @@ def try_creating_permissions_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK,
|
||||
name="connector_permission_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -263,12 +261,7 @@ def connector_permission_sync_generator_task(
|
||||
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
|
||||
)
|
||||
tasks_generated = redis_connector.permissions.generate_tasks(
|
||||
celery_app=self.app,
|
||||
lock=lock,
|
||||
new_permissions=document_external_accesses,
|
||||
source_string=source_type,
|
||||
connector_id=cc_pair.connector.id,
|
||||
credential_id=cc_pair.credential.id,
|
||||
self.app, lock, document_external_accesses, source_type
|
||||
)
|
||||
if tasks_generated is None:
|
||||
return None
|
||||
@@ -293,7 +286,7 @@ def connector_permission_sync_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
name="update_external_document_permissions_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES,
|
||||
@@ -304,8 +297,6 @@ def update_external_document_permissions_task(
|
||||
tenant_id: str | None,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> bool:
|
||||
document_external_access = DocExternalAccess.from_dict(
|
||||
serialized_doc_external_access
|
||||
@@ -314,28 +305,18 @@ def update_external_document_permissions_task(
|
||||
external_access = document_external_access.external_access
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
# Add the users to the DB if they don't exist
|
||||
# Then we build the update requests to update vespa
|
||||
batch_add_ext_perm_user_if_not_exists(
|
||||
db_session=db_session,
|
||||
emails=list(external_access.external_user_emails),
|
||||
)
|
||||
# Then we upsert the document's external permissions in postgres
|
||||
created_new_doc = upsert_document_external_perms(
|
||||
upsert_document_external_perms(
|
||||
db_session=db_session,
|
||||
doc_id=doc_id,
|
||||
external_access=external_access,
|
||||
source_type=DocumentSource(source_string),
|
||||
)
|
||||
|
||||
if created_new_doc:
|
||||
# If a new document was created, we associate it with the cc_pair
|
||||
upsert_document_by_connector_credential_pair(
|
||||
db_session=db_session,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
document_ids=[doc_id],
|
||||
)
|
||||
|
||||
logger.debug(
|
||||
f"Successfully synced postgres document permissions for {doc_id}"
|
||||
)
|
||||
|
||||
@@ -17,7 +17,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import mark_cc_pair_as_external_group_synced
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
@@ -32,14 +31,10 @@ from danswer.redis.redis_connector_ext_group_sync import (
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.utils.logger import setup_logger
|
||||
from ee.danswer.db.connector_credential_pair import get_all_auto_sync_cc_pairs
|
||||
from ee.danswer.db.connector_credential_pair import get_cc_pairs_by_source
|
||||
from ee.danswer.db.external_perm import ExternalUserGroup
|
||||
from ee.danswer.db.external_perm import replace_user__ext_group_for_cc_pair
|
||||
from ee.danswer.external_permissions.sync_params import EXTERNAL_GROUP_SYNC_PERIODS
|
||||
from ee.danswer.external_permissions.sync_params import GROUP_PERMISSIONS_FUNC_MAP
|
||||
from ee.danswer.external_permissions.sync_params import (
|
||||
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
|
||||
)
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -90,7 +85,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC,
|
||||
name="check_for_external_group_sync",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -111,22 +106,6 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> None:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
cc_pairs = get_all_auto_sync_cc_pairs(db_session)
|
||||
|
||||
# We only want to sync one cc_pair per source type in
|
||||
# GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC
|
||||
for source in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC:
|
||||
# These are ordered by cc_pair id so the first one is the one we want
|
||||
cc_pairs_to_dedupe = get_cc_pairs_by_source(
|
||||
db_session, source, only_sync=True
|
||||
)
|
||||
# We only want to sync one cc_pair per source type
|
||||
# in GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC so we dedupe here
|
||||
for cc_pair_to_remove in cc_pairs_to_dedupe[1:]:
|
||||
cc_pairs = [
|
||||
cc_pair
|
||||
for cc_pair in cc_pairs
|
||||
if cc_pair.id != cc_pair_to_remove.id
|
||||
]
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
if _is_external_group_sync_due(cc_pair):
|
||||
cc_pair_ids_to_sync.append(cc_pair.id)
|
||||
@@ -182,7 +161,7 @@ def try_creating_external_group_sync_task(
|
||||
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
|
||||
|
||||
result = app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
"connector_external_group_sync_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair_id,
|
||||
tenant_id=tenant_id,
|
||||
@@ -212,7 +191,7 @@ def try_creating_external_group_sync_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK,
|
||||
name="connector_external_group_sync_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
|
||||
@@ -23,7 +23,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.db.connector import mark_ccpair_with_indexing_trigger
|
||||
@@ -157,7 +156,7 @@ def get_unfenced_index_attempt_ids(db_session: Session, r: redis.Redis) -> list[
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
name="check_for_indexing",
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
@@ -487,7 +486,7 @@ def try_creating_indexing_task(
|
||||
# when the task is sent, we have yet to finish setting up the fence
|
||||
# therefore, the task must contain code that blocks until the fence is ready
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
"connector_indexing_proxy_task",
|
||||
kwargs=dict(
|
||||
index_attempt_id=index_attempt_id,
|
||||
cc_pair_id=cc_pair.id,
|
||||
@@ -525,10 +524,7 @@ def try_creating_indexing_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_INDEXING_PROXY_TASK,
|
||||
bind=True,
|
||||
acks_late=False,
|
||||
track_started=True,
|
||||
name="connector_indexing_proxy_task", bind=True, acks_late=False, track_started=True
|
||||
)
|
||||
def connector_indexing_proxy_task(
|
||||
self: Task,
|
||||
@@ -584,64 +580,39 @@ def connector_indexing_proxy_task(
|
||||
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
"Indexing watchdog - termination signal detected: "
|
||||
"Indexing proxy - termination signal detected: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
)
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
finally:
|
||||
# if the DB exceptions, we'll just get an unfriendly failure message
|
||||
# in the UI instead of the cancellation message
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception marking index attempt as canceled: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
mark_attempt_canceled(
|
||||
index_attempt_id,
|
||||
db_session,
|
||||
"Connector termination signal detected",
|
||||
)
|
||||
|
||||
job.cancel()
|
||||
|
||||
job.cancel()
|
||||
break
|
||||
|
||||
# do nothing for ongoing jobs that haven't been stopped
|
||||
if not job.done():
|
||||
# if the spawned task is still running, restart the check once again
|
||||
# if the index attempt is not in a finished status
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
except Exception:
|
||||
# if the DB exceptioned, just restart the check.
|
||||
# polling the index attempt status doesn't need to be strongly consistent
|
||||
logger.exception(
|
||||
"Indexing watchdog - transient exception looking up index attempt: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id}"
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
db_session=db_session, index_attempt_id=index_attempt_id
|
||||
)
|
||||
continue
|
||||
|
||||
if not index_attempt:
|
||||
continue
|
||||
|
||||
if not index_attempt.is_finished():
|
||||
continue
|
||||
|
||||
if job.status == "error":
|
||||
task_logger.error(
|
||||
"Indexing watchdog - spawned task exceptioned: "
|
||||
f"Indexing watchdog - spawned task exceptioned: "
|
||||
f"attempt={index_attempt_id} "
|
||||
f"tenant={tenant_id} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
@@ -789,12 +760,9 @@ def connector_indexing_task(
|
||||
)
|
||||
break
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
redis_connector_index.generator_lock_key,
|
||||
timeout=CELERY_INDEXING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
@@ -13,13 +13,12 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import PostgresAdvisoryLocks
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
name="kombu_message_cleanup_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
|
||||
@@ -8,7 +8,6 @@ from celery import shared_task
|
||||
from celery import Task
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
@@ -21,7 +20,6 @@ from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.connectors.factory import instantiate_connector
|
||||
from danswer.connectors.models import InputType
|
||||
@@ -77,7 +75,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_PRUNING,
|
||||
name="check_for_pruning",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
@@ -186,7 +184,7 @@ def try_creating_prune_generator_task(
|
||||
custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}"
|
||||
|
||||
celery_app.send_task(
|
||||
DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
"connector_pruning_generator_task",
|
||||
kwargs=dict(
|
||||
cc_pair_id=cc_pair.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
@@ -211,7 +209,7 @@ def try_creating_prune_generator_task(
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK,
|
||||
name="connector_pruning_generator_task",
|
||||
acks_late=False,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
track_started=True,
|
||||
@@ -240,12 +238,9 @@ def connector_pruning_generator_task(
|
||||
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# set thread_local=False since we don't control what thread the indexing/pruning
|
||||
# might run our callback with
|
||||
lock: RedisLock = r.lock(
|
||||
lock = r.lock(
|
||||
DanswerRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.id}",
|
||||
timeout=CELERY_PRUNING_LOCK_TIMEOUT,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(blocking=False)
|
||||
|
||||
@@ -9,7 +9,6 @@ from tenacity import RetryError
|
||||
from danswer.access.access import get_access_for_document
|
||||
from danswer.background.celery.apps.app_base import task_logger
|
||||
from danswer.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document import delete_document_by_connector_credential_pair__no_commit
|
||||
from danswer.db.document import delete_documents_complete__no_commit
|
||||
from danswer.db.document import get_document
|
||||
@@ -32,7 +31,7 @@ LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
name="document_by_cc_pair_cleanup_task",
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES,
|
||||
|
||||
@@ -25,7 +25,6 @@ from danswer.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT
|
||||
from danswer.configs.app_configs import JOB_TIMEOUT
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
from danswer.db.connector import fetch_connector_by_id
|
||||
from danswer.db.connector import mark_cc_pair_as_permissions_synced
|
||||
@@ -81,7 +80,7 @@ logger = setup_logger()
|
||||
# celery auto associates tasks created inside another task,
|
||||
# which bloats the result metadata considerably. trail=False prevents this.
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
|
||||
name="check_for_vespa_sync_task",
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
@@ -655,28 +654,24 @@ def monitor_ccpair_indexing_taskset(
|
||||
# outer = result.state in READY state
|
||||
status_int = redis_connector_index.get_completion()
|
||||
if status_int is None: # inner signal not set ... possible error
|
||||
task_state = result.state
|
||||
result_state = result.state
|
||||
if (
|
||||
task_state in READY_STATES
|
||||
result_state in READY_STATES
|
||||
): # outer signal in terminal state ... possible error
|
||||
# Now double check!
|
||||
if redis_connector_index.get_completion() is None:
|
||||
# inner signal still not set (and cannot change when outer result_state is READY)
|
||||
# Task is finished but generator complete isn't set.
|
||||
# We have a problem! Worker may have crashed.
|
||||
task_result = str(result.result)
|
||||
task_traceback = str(result.traceback)
|
||||
|
||||
msg = (
|
||||
f"Connector indexing aborted or exceptioned: "
|
||||
f"attempt={payload.index_attempt_id} "
|
||||
f"celery_task={payload.celery_task_id} "
|
||||
f"result_state={result_state} "
|
||||
f"cc_pair={cc_pair_id} "
|
||||
f"search_settings={search_settings_id} "
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f} "
|
||||
f"result.state={task_state} "
|
||||
f"result.result={task_result} "
|
||||
f"result.traceback={task_traceback}"
|
||||
f"elapsed_submitted={elapsed_submitted.total_seconds():.2f}"
|
||||
)
|
||||
task_logger.warning(msg)
|
||||
|
||||
@@ -708,7 +703,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
redis_connector_index.reset()
|
||||
|
||||
|
||||
@shared_task(name=DanswerCeleryTask.MONITOR_VESPA_SYNC, soft_time_limit=300, bind=True)
|
||||
@shared_task(name="monitor_vespa_sync", soft_time_limit=300, bind=True)
|
||||
def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
"""This is a celery beat task that monitors and finalizes metadata sync tasksets.
|
||||
It scans for fence values and then gets the counts of any associated tasksets.
|
||||
@@ -819,7 +814,7 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool:
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
name="vespa_metadata_sync_task",
|
||||
bind=True,
|
||||
soft_time_limit=LIGHT_SOFT_TIME_LIMIT,
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
|
||||
@@ -2,79 +2,20 @@ import re
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi.datastructures import Headers
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import is_user_admin
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import get_chat_messages_by_session
|
||||
from danswer.db.llm import fetch_existing_doc_sets
|
||||
from danswer.db.llm import fetch_existing_tools
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.db.models import Tool
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompts_by_ids
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.tools.tool_implementations.custom.custom_tool import (
|
||||
build_custom_tools_from_openapi_schema_and_headers,
|
||||
)
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def prepare_chat_message_request(
|
||||
message_text: str,
|
||||
user: User | None,
|
||||
persona_id: int | None,
|
||||
# Does the question need to have a persona override
|
||||
persona_override_config: PersonaOverrideConfig | None,
|
||||
prompt: Prompt | None,
|
||||
message_ts_to_respond_to: str | None,
|
||||
retrieval_details: RetrievalDetails | None,
|
||||
rerank_settings: RerankingDetails | None,
|
||||
db_session: Session,
|
||||
) -> CreateChatMessageRequest:
|
||||
# Typically used for one shot flows like SlackBot or non-chat API endpoint use cases
|
||||
new_chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description=None,
|
||||
user_id=user.id if user else None,
|
||||
# If using an override, this id will be ignored later on
|
||||
persona_id=persona_id or DEFAULT_PERSONA_ID,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=message_ts_to_respond_to,
|
||||
)
|
||||
|
||||
return CreateChatMessageRequest(
|
||||
chat_session_id=new_chat_session.id,
|
||||
parent_message_id=None, # It's a standalone chat session each time
|
||||
message=message_text,
|
||||
file_descriptors=[], # Currently SlackBot/answer api do not support files in the context
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
# Can always override the persona for the single query, if it's a normal persona
|
||||
# then it will be treated the same
|
||||
persona_override_config=persona_override_config,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=retrieval_details,
|
||||
rerank_settings=rerank_settings,
|
||||
)
|
||||
|
||||
|
||||
def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDoc:
|
||||
return LlmDoc(
|
||||
document_id=inference_section.center_chunk.document_id,
|
||||
@@ -90,49 +31,9 @@ def llm_doc_from_inference_section(inference_section: InferenceSection) -> LlmDo
|
||||
if inference_section.center_chunk.source_links
|
||||
else None,
|
||||
source_links=inference_section.center_chunk.source_links,
|
||||
match_highlights=inference_section.center_chunk.match_highlights,
|
||||
)
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
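To make the truncation behaviour of combine_message_thread concrete: messages are walked newest-first and prepended until the token budget is exhausted, so it is the oldest messages that get dropped. A stripped-down, hypothetical illustration using a naive whitespace tokenizer in place of BaseTokenizer:

# Hypothetical illustration of the newest-first token budgeting above.
def combine_with_budget(messages: list[str], max_tokens: int) -> str:
    kept: list[str] = []
    total = 0
    for msg in reversed(messages):  # newest message first
        cost = len(msg.split())  # naive stand-in for llm_tokenizer.encode
        if total + cost > max_tokens:
            break  # oldest messages fall off first
        kept.insert(0, msg)  # restore chronological order
        total += cost
    return "\n\n".join(kept)
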
|
||||
|
||||
|
||||
def create_chat_chain(
|
||||
chat_session_id: UUID,
|
||||
db_session: Session,
|
||||
@@ -295,71 +196,3 @@ def extract_headers(
|
||||
if lowercase_key in headers:
|
||||
extracted_headers[lowercase_key] = headers[lowercase_key]
|
||||
return extracted_headers
|
||||
|
||||
|
||||
def create_temporary_persona(
|
||||
persona_config: PersonaOverrideConfig, db_session: Session, user: User | None = None
|
||||
) -> Persona:
|
||||
if not is_user_admin(user):
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail="User is not authorized to create a persona in one shot queries",
|
||||
)
|
||||
|
||||
"""Create a temporary Persona object from the provided configuration."""
|
||||
persona = Persona(
|
||||
name=persona_config.name,
|
||||
description=persona_config.description,
|
||||
num_chunks=persona_config.num_chunks,
|
||||
llm_relevance_filter=persona_config.llm_relevance_filter,
|
||||
llm_filter_extraction=persona_config.llm_filter_extraction,
|
||||
recency_bias=persona_config.recency_bias,
|
||||
llm_model_provider_override=persona_config.llm_model_provider_override,
|
||||
llm_model_version_override=persona_config.llm_model_version_override,
|
||||
)
|
||||
|
||||
if persona_config.prompts:
|
||||
persona.prompts = [
|
||||
Prompt(
|
||||
name=p.name,
|
||||
description=p.description,
|
||||
system_prompt=p.system_prompt,
|
||||
task_prompt=p.task_prompt,
|
||||
include_citations=p.include_citations,
|
||||
datetime_aware=p.datetime_aware,
|
||||
)
|
||||
for p in persona_config.prompts
|
||||
]
|
||||
elif persona_config.prompt_ids:
|
||||
persona.prompts = get_prompts_by_ids(
|
||||
db_session=db_session, prompt_ids=persona_config.prompt_ids
|
||||
)
|
||||
|
||||
persona.tools = []
|
||||
if persona_config.custom_tools_openapi:
|
||||
for schema in persona_config.custom_tools_openapi:
|
||||
tools = cast(
|
||||
list[Tool],
|
||||
build_custom_tools_from_openapi_schema_and_headers(schema),
|
||||
)
|
||||
persona.tools.extend(tools)
|
||||
|
||||
if persona_config.tools:
|
||||
tool_ids = [tool.id for tool in persona_config.tools]
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(db_session=db_session, tool_ids=tool_ids)
|
||||
)
|
||||
|
||||
if persona_config.tool_ids:
|
||||
persona.tools.extend(
|
||||
fetch_existing_tools(
|
||||
db_session=db_session, tool_ids=persona_config.tool_ids
|
||||
)
|
||||
)
|
||||
|
||||
fetched_docs = fetch_existing_doc_sets(
|
||||
db_session=db_session, doc_ids=persona_config.document_set_ids
|
||||
)
|
||||
persona.document_sets = fetched_docs
|
||||
|
||||
return persona
|
||||
|
||||
@@ -4,14 +4,12 @@ from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import RetrievalDocs
|
||||
from danswer.context.search.models import SearchResponse
|
||||
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType
|
||||
|
||||
|
||||
@@ -27,7 +25,6 @@ class LlmDoc(BaseModel):
|
||||
updated_at: datetime | None
|
||||
link: str | None
|
||||
source_links: dict[int, str] | None
|
||||
match_highlights: list[str] | None
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
@@ -120,6 +117,20 @@ class StreamingError(BaseModel):
|
||||
stack_trace: str | None = None
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
class DanswerContext(BaseModel):
|
||||
content: str
|
||||
document_id: str
|
||||
@@ -135,20 +146,14 @@ class DanswerAnswer(BaseModel):
|
||||
answer: str | None
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class ChatDanswerBotResponse(BaseModel):
|
||||
answer: str | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
class QAResponse(SearchResponse, DanswerAnswer):
|
||||
quotes: list[DanswerQuote] | None
|
||||
contexts: list[DanswerContexts] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
eval_res_valid: bool | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
chat_message_id: int | None = None
|
||||
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
|
||||
|
||||
|
||||
class FileChatDisplay(BaseModel):
|
||||
@@ -160,41 +165,9 @@ class CustomToolResponse(BaseModel):
|
||||
tool_name: str
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PromptOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class PersonaOverrideConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptOverrideConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
AnswerQuestionPossibleReturn = (
|
||||
DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| CitationInfo
|
||||
| DanswerContexts
|
||||
| FileChatDisplay
|
||||
|
||||
@@ -7,13 +7,10 @@ from typing import cast
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import create_chat_chain
|
||||
from danswer.chat.chat_utils import create_temporary_persona
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import CustomToolResponse
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import FileChatDisplay
|
||||
from danswer.chat.models import FinalUsedContextDocsResponse
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
@@ -105,7 +102,6 @@ from danswer.tools.tool_implementations.internet_search.internet_search_tool imp
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
@@ -117,10 +113,8 @@ from danswer.tools.tool_implementations.search.search_tool import (
|
||||
from danswer.tools.tool_runner import ToolCallFinalResult
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
from danswer.utils.timing import log_function_time
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@@ -262,7 +256,6 @@ def _get_force_search_settings(
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| DanswerContexts
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
@@ -293,8 +286,6 @@ def stream_chat_message_objects(
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
is_connected: Callable[[], bool] | None = None,
|
||||
enforce_chat_session_id_for_search_docs: bool = True,
|
||||
bypass_acl: bool = False,
|
||||
include_contexts: bool = False,
|
||||
) -> ChatPacketStream:
|
||||
"""Streams in order:
|
||||
1. [conditional] Retrieved documents if a search needs to be run
|
||||
@@ -331,31 +322,17 @@ def stream_chat_message_objects(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session_id)}
|
||||
)
|
||||
|
||||
# use alternate persona if alternative assistant id is passed in
|
||||
if alternate_assistant_id is not None:
|
||||
# Allows users to specify a temporary persona (assistant) in the chat session
|
||||
# this takes highest priority since it's user specified
|
||||
persona = get_persona_by_id(
|
||||
alternate_assistant_id,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
is_for_edit=False,
|
||||
)
|
||||
elif new_msg_req.persona_override_config:
|
||||
# Certain endpoints allow users to specify arbitrary persona settings
|
||||
# this should never conflict with the alternate_assistant_id
|
||||
persona = persona = create_temporary_persona(
|
||||
db_session=db_session,
|
||||
persona_config=new_msg_req.persona_override_config,
|
||||
user=user,
|
||||
)
|
||||
else:
|
||||
persona = chat_session.persona
|
||||
|
||||
if not persona:
|
||||
raise RuntimeError("No persona specified or found for chat session")
|
||||
|
||||
# If a prompt override is specified via the API, use that with highest priority
|
||||
# but for saving it, we are just mapping it to an existing prompt
|
||||
prompt_id = new_msg_req.prompt_id
|
||||
if prompt_id is None and persona.prompts:
|
||||
prompt_id = sorted(persona.prompts, key=lambda x: x.id)[-1].id
|
||||
@@ -578,34 +555,19 @@ def stream_chat_message_objects(
|
||||
reserved_message_id=reserved_message_id,
|
||||
)
|
||||
|
||||
prompt_override = new_msg_req.prompt_override or chat_session.prompt_override
|
||||
if new_msg_req.persona_override_config:
|
||||
prompt_config = PromptConfig(
|
||||
system_prompt=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].system_prompt,
|
||||
task_prompt=new_msg_req.persona_override_config.prompts[0].task_prompt,
|
||||
datetime_aware=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].datetime_aware,
|
||||
include_citations=new_msg_req.persona_override_config.prompts[
|
||||
0
|
||||
].include_citations,
|
||||
)
|
||||
elif prompt_override:
|
||||
if not final_msg.prompt:
|
||||
raise ValueError(
|
||||
"Prompt override cannot be applied, no base prompt found."
|
||||
)
|
||||
prompt_config = PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=prompt_override,
|
||||
)
|
||||
elif final_msg.prompt:
|
||||
prompt_config = PromptConfig.from_model(final_msg.prompt)
|
||||
else:
|
||||
prompt_config = PromptConfig.from_model(persona.prompts[0])
|
||||
if not final_msg.prompt:
|
||||
raise RuntimeError("No Prompt found")
|
||||
|
||||
prompt_config = (
|
||||
PromptConfig.from_model(
|
||||
final_msg.prompt,
|
||||
prompt_override=(
|
||||
new_msg_req.prompt_override or chat_session.prompt_override
|
||||
),
|
||||
)
|
||||
if not persona
|
||||
else PromptConfig.from_model(persona.prompts[0])
|
||||
)
|
||||
answer_style_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig(
|
||||
all_docs_useful=selected_db_search_docs is not None
|
||||
@@ -625,13 +587,11 @@ def stream_chat_message_objects(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
rerank_settings=new_msg_req.rerank_settings,
|
||||
selected_sections=selected_sections,
|
||||
chunks_above=new_msg_req.chunks_above,
|
||||
chunks_below=new_msg_req.chunks_below,
|
||||
full_doc=new_msg_req.full_doc,
|
||||
latest_query_files=latest_query_files,
|
||||
bypass_acl=bypass_acl,
|
||||
),
|
||||
internet_search_tool_config=InternetSearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
@@ -777,8 +737,6 @@ def stream_chat_message_objects(
                response=custom_tool_response.tool_result,
                tool_name=custom_tool_response.tool_name,
            )
        elif packet.id == SEARCH_DOC_CONTENT_ID and include_contexts:
            yield cast(DanswerContexts, packet.response)

    elif isinstance(packet, StreamStopInfo):
        pass
@@ -887,30 +845,3 @@ def stream_chat_message(
        )
        for obj in objects:
            yield get_json_line(obj.model_dump())


@log_function_time()
def gather_stream_for_slack(
    packets: ChatPacketStream,
) -> ChatDanswerBotResponse:
    response = ChatDanswerBotResponse()

    answer = ""
    for packet in packets:
        if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
            answer += packet.answer_piece
        elif isinstance(packet, QADocsResponse):
            response.docs = packet
        elif isinstance(packet, StreamingError):
            response.error_msg = packet.error
        elif isinstance(packet, ChatMessageDetail):
            response.chat_message_id = packet.message_id
        elif isinstance(packet, LLMRelevanceFilterResponse):
            response.llm_selected_doc_indices = packet.llm_selected_doc_indices
        elif isinstance(packet, AllCitations):
            response.citations = packet.citations

    if answer:
        response.answer = answer

    return response

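For reference, the Slack handler further down this compare view consumes this stream roughly as follows. This is a condensed sketch only; the wrapper function name and the db_session keyword are assumptions, not code from these commits.

# Condensed sketch of the new Slack flow (see handle_regular_answer below).
# The wrapper name and the db_session keyword are assumptions.
from danswer.chat.process_message import gather_stream_for_slack
from danswer.chat.process_message import stream_chat_message_objects
from danswer.db.engine import get_session_with_tenant


def answer_for_slack(new_message_request, danswer_user, tenant_id):
    with get_session_with_tenant(tenant_id) as db_session:
        packets = stream_chat_message_objects(
            new_msg_req=new_message_request,
            user=danswer_user,
            db_session=db_session,  # assumed keyword, mirroring the handler diff
        )
        # Drain the generator while the session is still open
        answer = gather_stream_for_slack(packets)

    if answer.error_msg:
        raise RuntimeError(answer.error_msg)
    return answer
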
@@ -308,22 +308,6 @@ CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
    os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)

# Due to breakages in the Confluence API, the timezone offset must be specified client side
# to match the user's specified timezone.

# The current state of affairs:
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
# no API retrieves the user's timezone
# All data is returned in UTC, so we can't derive the user's timezone from that

# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670

# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to
# per-connector at some point.
CONFLUENCE_TIMEZONE_OFFSET = float(os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", 0.0))
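To make the offset concrete, here is a minimal, illustrative sketch of how a float hour offset becomes a tzinfo and feeds the CQL lastmodified filter, mirroring the ConfluenceConnector code further down this diff; the sample values are made up.

# Illustrative only: apply a float hour offset the way the connector does.
from datetime import datetime, timedelta, timezone

offset_hours = 5.5  # e.g. CONFLUENCE_TIMEZONE_OFFSET for a user in IST
user_tz = timezone(offset=timedelta(hours=offset_hours))

poll_start = 1_700_000_000.0  # unix timestamp for the start of the poll window
formatted_start = datetime.fromtimestamp(poll_start, tz=user_tz).strftime("%Y-%m-%d %H:%M")
cql_time_filter = f" and lastmodified >= '{formatted_start}'"  # parsed in the user's timezone by Confluence
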

JIRA_CONNECTOR_LABELS_TO_SKIP = [
    ignored_tag
    for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
@@ -522,6 +506,3 @@ API_KEY_HASH_ROUNDS = (
|
||||
|
||||
POD_NAME = os.environ.get("POD_NAME")
|
||||
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
|
||||
|
||||
|
||||
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
|
||||
|
||||
@@ -31,8 +31,6 @@ DISABLED_GEN_AI_MSG = (
|
||||
"You can still use Danswer as a search engine."
|
||||
)
|
||||
|
||||
DEFAULT_PERSONA_ID = 0
|
||||
|
||||
# Postgres connection constants for application_name
|
||||
POSTGRES_WEB_APP_NAME = "web"
|
||||
POSTGRES_INDEXER_APP_NAME = "indexer"
|
||||
@@ -261,32 +259,6 @@ class DanswerCeleryPriority(int, Enum):
|
||||
LOWEST = auto()
|
||||
|
||||
|
||||
class DanswerCeleryTask:
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
CHECK_FOR_INDEXING = "check_for_indexing"
|
||||
CHECK_FOR_PRUNING = "check_for_pruning"
|
||||
CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync"
|
||||
CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync"
|
||||
MONITOR_VESPA_SYNC = "monitor_vespa_sync"
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = (
|
||||
"update_external_document_permissions_task"
|
||||
)
|
||||
CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = (
|
||||
"connector_external_group_sync_generator_task"
|
||||
)
|
||||
CONNECTOR_INDEXING_PROXY_TASK = "connector_indexing_proxy_task"
|
||||
CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task"
|
||||
DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task"
|
||||
VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task"
|
||||
CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task"
|
||||
AUTOGENERATE_USAGE_REPORT_TASK = "autogenerate_usage_report_task"
|
||||


REDIS_SOCKET_KEEPALIVE_OPTIONS = {}
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15
REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3

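These keepalive constants are normally passed through to the Redis client when connections are created. Below is a hedged sketch with redis-py; the host, port, and exact call site are assumptions, and the TCP_KEEP* constants are platform dependent (Linux/macOS).

# Hedged sketch: wiring the keepalive options into a redis-py client.
import socket

import redis

keepalive_options = {
    socket.TCP_KEEPINTVL: 15,  # seconds between keepalive probes
    socket.TCP_KEEPCNT: 3,  # failed probes before the connection is dropped
}

r = redis.Redis(
    host="localhost",  # assumed; the real pool is built elsewhere in the repo
    port=6379,
    socket_keepalive=True,
    socket_keepalive_options=keepalive_options,
)
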
@@ -4,8 +4,11 @@ import os
# Danswer Slack Bot Configs
#####
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
    os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
)
# How much of the available input context can be used for thread context
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072
DANSWER_BOT_TARGET_CHUNK_PERCENTAGE = 512 * 2 / 3072
# Number of docs to display in "Reference Documents"
DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
    os.environ.get("DANSWER_BOT_NUM_DOCS_TO_DISPLAY", "5")
@@ -44,6 +47,17 @@ DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
DANSWER_BOT_RESPOND_EVERY_CHANNEL = (
    os.environ.get("DANSWER_BOT_RESPOND_EVERY_CHANNEL", "").lower() == "true"
)
# Add a second LLM call after the answer is generated to verify that it is valid
# Throws out answers that don't directly or fully answer the user query
# This is the default for all DanswerBot channels unless the channel is configured individually
# Set/unset by "Hide Non Answers"
ENABLE_DANSWERBOT_REFLEXION = (
    os.environ.get("ENABLE_DANSWERBOT_REFLEXION", "").lower() == "true"
)
# Chain of thought is currently not supported; it will probably be added back later
DANSWER_BOT_DISABLE_COT = True
# If set, DanswerBot will default to using quotes and reference documents
DANSWER_BOT_USE_QUOTES = os.environ.get("DANSWER_BOT_USE_QUOTES", "").lower() == "true"

# Maximum Questions Per Minute, Default Uncapped
DANSWER_BOT_MAX_QPM = int(os.environ.get("DANSWER_BOT_MAX_QPM") or 0) or None

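The two 512 * 2 / 3072 constants above cap how much of the model's input window the Slack flow spends on thread history. A rough sketch of the resulting split follows, based on the arithmetic in handle_regular_answer later in this compare view; the 512-token reserve comes from that code, and the example numbers are arbitrary.

# Rough sketch of the Slack token budget split; example numbers only.
MAX_THREAD_CONTEXT_PERCENTAGE = 512 * 2 / 3072  # ~1/3 of the input context


def split_token_budget(input_tokens: int, query_tokens: int) -> tuple[int, int]:
    """Return (max_history_tokens, max_document_tokens)."""
    max_history_tokens = int(input_tokens * MAX_THREAD_CONTEXT_PERCENTAGE)
    remaining_tokens = input_tokens - max_history_tokens
    # 512 is reserved so the budget stays larger than any of the QA prompts
    max_document_tokens = remaining_tokens - 512 - query_tokens
    return max_history_tokens, max_document_tokens


print(split_token_budget(input_tokens=4096, query_tokens=50))  # (1365, 2169)
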
@@ -1,11 +1,9 @@
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from urllib.parse import quote
|
||||
|
||||
from danswer.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP
|
||||
from danswer.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
|
||||
from danswer.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from danswer.configs.app_configs import INDEX_BATCH_SIZE
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -71,7 +69,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
# skip it. This is generally used to avoid indexing extra sensitive
|
||||
# pages.
|
||||
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
|
||||
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
|
||||
) -> None:
|
||||
self.batch_size = batch_size
|
||||
self.continue_on_failure = continue_on_failure
|
||||
@@ -107,8 +104,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
)
|
||||
self.cql_label_filter = f" and label not in ({comma_separated_labels})"
|
||||
|
||||
self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset))
|
||||
|
||||
@property
|
||||
def confluence_client(self) -> OnyxConfluence:
|
||||
if self._confluence_client is None:
|
||||
@@ -209,14 +204,12 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
confluence_page_ids: list[str] = []
|
||||
|
||||
page_query = self.cql_page_query + self.cql_label_filter + self.cql_time_filter
|
||||
logger.debug(f"page_query: {page_query}")
|
||||
# Fetch pages as Documents
|
||||
for page in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=page_query,
|
||||
expand=",".join(_PAGE_EXPANSION_FIELDS),
|
||||
limit=self.batch_size,
|
||||
):
|
||||
logger.debug(f"_fetch_document_batches: {page['id']}")
|
||||
confluence_page_ids.append(page["id"])
|
||||
doc = self._convert_object_to_document(page)
|
||||
if doc is not None:
|
||||
@@ -249,10 +242,10 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
|
||||
def poll_source(self, start: float, end: float) -> GenerateDocumentsOutput:
|
||||
# Add time filters
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=self.timezone).strftime(
|
||||
formatted_start_time = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime(
|
||||
formatted_end_time = datetime.fromtimestamp(end, tz=timezone.utc).strftime(
|
||||
"%Y-%m-%d %H:%M"
|
||||
)
|
||||
self.cql_time_filter = f" and lastmodified >= '{formatted_start_time}'"
|
||||
|
||||
@@ -134,32 +134,6 @@ class OnyxConfluence(Confluence):
        super(OnyxConfluence, self).__init__(url, *args, **kwargs)
        self._wrap_methods()

    def get_current_user(self, expand: str | None = None) -> Any:
        """
        Implements a method that isn't in the third party client.

        Get information about the current user
        :param expand: OPTIONAL expand for get status of user.
                       Possible param is "status". Results are "Active, Deactivated"
        :return: Returns the user details
        """

        from atlassian.errors import ApiPermissionError  # type:ignore

        url = "rest/api/user/current"
        params = {}
        if expand:
            params["expand"] = expand
        try:
            response = self.get(url, params=params)
        except HTTPError as e:
            if e.response.status_code == 403:
                raise ApiPermissionError(
                    "The calling user does not have permission", reason=e
                )
            raise
        return response

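A hypothetical call site for the helper above; the import path and the constructor arguments (passed through to the underlying atlassian Confluence client) are assumptions, not code from these commits.

# Hypothetical usage; import path and credentials are assumptions.
from danswer.connectors.confluence.onyx_confluence import OnyxConfluence

client = OnyxConfluence(
    url="https://example.atlassian.net/wiki",
    username="bot@example.com",
    password="api-token",
    cloud=True,
)
me = client.get_current_user(expand="status")
print(me.get("displayName"))
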
def _wrap_methods(self) -> None:
|
||||
"""
|
||||
For each attribute that is callable (i.e., a method) and doesn't start with an underscore,
|
||||
@@ -332,13 +306,6 @@ def _validate_connector_configuration(
    )
    spaces = confluence_client_with_minimal_retries.get_all_spaces(limit=1)

    # uncomment the following for testing
    # the following is an attempt to retrieve the user's timezone
    # Unfortunately, all data is returned in UTC regardless of the user's time zone
    # even though CQL parses incoming times based on the user's time zone
    # space_key = spaces["results"][0]["key"]
    # space_details = confluence_client_with_minimal_retries.cql(f"space.key={space_key}+AND+type=space")

    if not spaces:
        raise RuntimeError(
            f"No spaces found at {wiki_base}! "

@@ -32,11 +32,7 @@ def get_user_email_from_username__server(
|
||||
response = confluence_client.get_mobile_parameters(user_name)
|
||||
email = response.get("email")
|
||||
except Exception:
|
||||
# For now, we'll just return a string that indicates failure
|
||||
# We may want to revert to returning None in the future
|
||||
# email = None
|
||||
email = f"FAILED TO GET CONFLUENCE EMAIL FOR {user_name}"
|
||||
logger.warning(f"failed to get confluence email for {user_name}")
|
||||
email = None
|
||||
_USER_EMAIL_CACHE[user_name] = email
|
||||
return _USER_EMAIL_CACHE[user_name]
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from slack_sdk.models.blocks import SectionBlock
|
||||
from slack_sdk.models.blocks.basic_components import MarkdownTextObject
|
||||
from slack_sdk.models.blocks.block_elements import ImageElement
|
||||
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.configs.constants import DocumentSource
|
||||
@@ -40,7 +40,10 @@ from danswer.danswerbot.slack.utils import translate_vespa_highlight_to_slack
|
||||
from danswer.db.chat import get_chat_session_by_message_id
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.text_processing import decode_escapes
|
||||
from danswer.utils.text_processing import replace_whitespaces_w_space
|
||||
|
||||
_MAX_BLURB_LEN = 45
|
||||
|
||||
@@ -324,7 +327,7 @@ def _build_sources_blocks(
|
||||
|
||||
|
||||
def _priority_ordered_documents_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
@@ -347,7 +350,7 @@ def _priority_ordered_documents_blocks(
|
||||
|
||||
|
||||
def _build_citations_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
) -> list[Block]:
|
||||
docs_response = answer.docs if answer.docs else None
|
||||
top_docs = docs_response.top_documents if docs_response else []
|
||||
@@ -366,8 +369,51 @@ def _build_citations_blocks(
|
||||
return citations_block
|
||||
|
||||
|
||||
def _build_quotes_block(
|
||||
quotes: list[DanswerQuote],
|
||||
) -> list[Block]:
|
||||
quote_lines: list[str] = []
|
||||
doc_to_quotes: dict[str, list[str]] = {}
|
||||
doc_to_link: dict[str, str] = {}
|
||||
doc_to_sem_id: dict[str, str] = {}
|
||||
for q in quotes:
|
||||
quote = q.quote
|
||||
doc_id = q.document_id
|
||||
doc_link = q.link
|
||||
doc_name = q.semantic_identifier
|
||||
if doc_link and doc_name and doc_id and quote:
|
||||
if doc_id not in doc_to_quotes:
|
||||
doc_to_quotes[doc_id] = [quote]
|
||||
doc_to_link[doc_id] = doc_link
|
||||
doc_to_sem_id[doc_id] = (
|
||||
doc_name
|
||||
if q.source_type != DocumentSource.SLACK.value
|
||||
else "#" + doc_name
|
||||
)
|
||||
else:
|
||||
doc_to_quotes[doc_id].append(quote)
|
||||
|
||||
for doc_id, quote_strs in doc_to_quotes.items():
|
||||
quotes_str_clean = [
|
||||
replace_whitespaces_w_space(q_str).strip() for q_str in quote_strs
|
||||
]
|
||||
longest_quotes = sorted(quotes_str_clean, key=len, reverse=True)[:5]
|
||||
single_quote_str = "\n".join([f"```{q_str}```" for q_str in longest_quotes])
|
||||
link = doc_to_link[doc_id]
|
||||
sem_id = doc_to_sem_id[doc_id]
|
||||
quote_lines.append(
|
||||
f"<{link}|{sem_id}>:\n{remove_slack_text_interactions(single_quote_str)}"
|
||||
)
|
||||
|
||||
if not doc_to_quotes:
|
||||
return []
|
||||
|
||||
return [SectionBlock(text="*Relevant Snippets*\n" + "\n".join(quote_lines))]
|
||||
|
||||
|
||||
def _build_qa_response_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
answer: OneShotQAResponse,
|
||||
skip_quotes: bool = False,
|
||||
process_message_for_citations: bool = False,
|
||||
) -> list[Block]:
|
||||
retrieval_info = answer.docs
|
||||
@@ -376,10 +422,13 @@ def _build_qa_response_blocks(
|
||||
raise RuntimeError("Failed to retrieve docs, cannot answer question.")
|
||||
|
||||
formatted_answer = format_slack_message(answer.answer) if answer.answer else None
|
||||
quotes = answer.quotes.quotes if answer.quotes else None
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return []
|
||||
|
||||
quotes_blocks: list[Block] = []
|
||||
|
||||
filter_block: Block | None = None
|
||||
if (
|
||||
retrieval_info.applied_time_cutoff
|
||||
@@ -422,6 +471,16 @@ def _build_qa_response_blocks(
|
||||
answer_blocks = [
|
||||
SectionBlock(text=text) for text in _split_text(answer_processed)
|
||||
]
|
||||
if quotes:
|
||||
quotes_blocks = _build_quotes_block(quotes)
|
||||
|
||||
# if no quotes OR `_build_quotes_block()` did not give back any blocks
|
||||
if not quotes_blocks:
|
||||
quotes_blocks = [
|
||||
SectionBlock(
|
||||
text="*Warning*: no sources were quoted for this answer, so it may be unreliable 😔"
|
||||
)
|
||||
]
|
||||
|
||||
response_blocks: list[Block] = []
|
||||
|
||||
@@ -430,6 +489,9 @@ def _build_qa_response_blocks(
|
||||
|
||||
response_blocks.extend(answer_blocks)
|
||||
|
||||
if not skip_quotes:
|
||||
response_blocks.extend(quotes_blocks)
|
||||
|
||||
return response_blocks
|
||||
|
||||
|
||||
@@ -505,9 +567,10 @@ def build_follow_up_resolved_blocks(
|
||||
|
||||
|
||||
def build_slack_response_blocks(
|
||||
answer: ChatDanswerBotResponse,
|
||||
tenant_id: str | None,
|
||||
message_info: SlackMessageInfo,
|
||||
answer: OneShotQAResponse,
|
||||
persona: Persona | None,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
feedback_reminder_id: str | None,
|
||||
@@ -524,6 +587,7 @@ def build_slack_response_blocks(
|
||||
|
||||
answer_blocks = _build_qa_response_blocks(
|
||||
answer=answer,
|
||||
skip_quotes=persona is not None or use_citations,
|
||||
process_message_for_citations=use_citations,
|
||||
)
|
||||
|
||||
@@ -553,7 +617,8 @@ def build_slack_response_blocks(
|
||||
|
||||
citations_blocks = []
|
||||
document_blocks = []
|
||||
if use_citations and answer.citations:
|
||||
if use_citations:
|
||||
# if citations are enabled, only show cited documents
|
||||
citations_blocks = _build_citations_blocks(answer)
|
||||
else:
|
||||
document_blocks = _priority_ordered_documents_blocks(answer)
|
||||
@@ -572,5 +637,4 @@ def build_slack_response_blocks(
|
||||
+ web_follow_up_block
|
||||
+ follow_up_block
|
||||
)
|
||||
|
||||
return all_blocks
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import functools
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
@@ -8,36 +9,46 @@ from retry import retry
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.models.blocks import SectionBlock
|
||||
|
||||
from danswer.chat.chat_utils import prepare_chat_message_request
|
||||
from danswer.chat.models import ChatDanswerBotResponse
|
||||
from danswer.chat.process_message import gather_stream_for_slack
|
||||
from danswer.chat.process_message import stream_chat_message_objects
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.constants import DEFAULT_PERSONA_ID
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_COT
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_NUM_RETRIES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import DANSWER_BOT_USE_QUOTES
|
||||
from danswer.configs.danswerbot_configs import DANSWER_FOLLOWUP_EMOJI
|
||||
from danswer.configs.danswerbot_configs import DANSWER_REACT_EMOJI
|
||||
from danswer.configs.danswerbot_configs import MAX_THREAD_CONTEXT_PERCENTAGE
|
||||
from danswer.configs.danswerbot_configs import ENABLE_DANSWERBOT_REFLEXION
|
||||
from danswer.context.search.enums import OptionalSearchSetting
|
||||
from danswer.context.search.models import BaseFilters
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.danswerbot.slack.blocks import build_slack_response_blocks
|
||||
from danswer.danswerbot.slack.handlers.utils import send_team_member_message
|
||||
from danswer.danswerbot.slack.handlers.utils import slackify_message_thread
|
||||
from danswer.danswerbot.slack.models import SlackMessageInfo
|
||||
from danswer.danswerbot.slack.utils import respond_in_thread
|
||||
from danswer.danswerbot.slack.utils import SlackRateLimiter
|
||||
from danswer.danswerbot.slack.utils import update_emote_react
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_persona_by_id
|
||||
from danswer.db.persona import fetch_persona_by_id
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.llm.answering.prompts.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.utils import check_number_of_tokens
|
||||
from danswer.llm.utils import get_max_input_tokens
|
||||
from danswer.one_shot_answer.answer_question import get_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.utils.logger import DanswerLoggingAdapter
|
||||
|
||||
|
||||
srl = SlackRateLimiter()
|
||||
|
||||
RT = TypeVar("RT") # return type
|
||||
@@ -72,14 +83,16 @@ def handle_regular_answer(
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
answer_generation_timeout: int = DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
|
||||
thread_context_percent: float = DANSWER_BOT_TARGET_CHUNK_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
disable_docs_only_answer: bool = DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER,
|
||||
disable_cot: bool = DANSWER_BOT_DISABLE_COT,
|
||||
reflexion: bool = ENABLE_DANSWERBOT_REFLEXION,
|
||||
) -> bool:
|
||||
channel_conf = slack_channel_config.channel_config if slack_channel_config else None
|
||||
|
||||
messages = message_info.thread_messages
|
||||
|
||||
message_ts_to_respond_to = message_info.msg_to_respond
|
||||
is_bot_msg = message_info.is_bot_msg
|
||||
user = None
|
||||
@@ -89,18 +102,9 @@ def handle_regular_answer(
|
||||
user = get_user_by_email(message_info.email, db_session)
|
||||
|
||||
document_set_names: list[str] | None = None
|
||||
prompt = None
|
||||
# If no persona is specified, use the default search based persona
|
||||
# This way slack flow always has a persona
|
||||
persona = slack_channel_config.persona if slack_channel_config else None
|
||||
if not persona:
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
prompt = persona.prompts[0] if persona.prompts else None
|
||||
else:
|
||||
prompt = None
|
||||
if persona:
|
||||
document_set_names = [
|
||||
document_set.name for document_set in persona.document_sets
|
||||
]
|
||||
@@ -108,26 +112,6 @@ def handle_regular_answer(
|
||||
|
||||
should_respond_even_with_no_docs = persona.num_chunks == 0 if persona else False
|
||||
|
||||
# TODO: Add in support for Slack to truncate messages based on max LLM context
|
||||
# llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# llm_tokenizer = get_tokenizer(
|
||||
# model_name=llm.config.model_name,
|
||||
# provider_type=llm.config.model_provider,
|
||||
# )
|
||||
|
||||
# # In cases of threads, split the available tokens between docs and thread context
|
||||
# input_tokens = get_max_input_tokens(
|
||||
# model_name=llm.config.model_name,
|
||||
# model_provider=llm.config.model_provider,
|
||||
# )
|
||||
# max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
# combined_message = combine_message_thread(
|
||||
# messages, max_tokens=max_history_tokens, llm_tokenizer=llm_tokenizer
|
||||
# )
|
||||
|
||||
combined_message = slackify_message_thread(messages)
|
||||
|
||||
bypass_acl = False
|
||||
if (
|
||||
slack_channel_config
|
||||
@@ -138,6 +122,13 @@ def handle_regular_answer(
|
||||
# with non-public document sets
|
||||
bypass_acl = True
|
||||
|
||||
# figure out if we want to use citations or quotes
|
||||
use_citations = (
|
||||
not DANSWER_BOT_USE_QUOTES
|
||||
if slack_channel_config is None
|
||||
else slack_channel_config.response_type == SlackBotResponseType.CITATIONS
|
||||
)
|
||||
|
||||
if not message_ts_to_respond_to and not is_bot_msg:
|
||||
# if the message is not a "/danswer" command, then it should have a message ts to respond to
|
||||
raise RuntimeError(
|
||||
@@ -150,23 +141,75 @@ def handle_regular_answer(
|
||||
backoff=2,
|
||||
)
|
||||
@rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to)
|
||||
def _get_slack_answer(
|
||||
new_message_request: CreateChatMessageRequest, danswer_user: User | None
|
||||
) -> ChatDanswerBotResponse:
|
||||
def _get_answer(new_message_request: DirectQARequest) -> OneShotQAResponse | None:
|
||||
max_document_tokens: int | None = None
|
||||
max_history_tokens: int | None = None
|
||||
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
packets = stream_chat_message_objects(
|
||||
new_msg_req=new_message_request,
|
||||
user=danswer_user,
|
||||
if len(new_message_request.messages) > 1:
|
||||
if new_message_request.persona_config:
|
||||
raise RuntimeError("Slack bot does not support persona config")
|
||||
elif new_message_request.persona_id is not None:
|
||||
persona = cast(
|
||||
Persona,
|
||||
fetch_persona_by_id(
|
||||
db_session,
|
||||
new_message_request.persona_id,
|
||||
user=None,
|
||||
get_editable=False,
|
||||
),
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"No persona id provided, this should never happen."
|
||||
)
|
||||
|
||||
llm, _ = get_llms_for_persona(persona)
|
||||
|
||||
# In cases of threads, split the available tokens between docs and thread context
|
||||
input_tokens = get_max_input_tokens(
|
||||
model_name=llm.config.model_name,
|
||||
model_provider=llm.config.model_provider,
|
||||
)
|
||||
max_history_tokens = int(input_tokens * thread_context_percent)
|
||||
|
||||
remaining_tokens = input_tokens - max_history_tokens
|
||||
|
||||
query_text = new_message_request.messages[0].message
|
||||
if persona:
|
||||
max_document_tokens = compute_max_document_tokens_for_persona(
|
||||
persona=persona,
|
||||
actual_user_input=query_text,
|
||||
max_llm_token_override=remaining_tokens,
|
||||
)
|
||||
else:
|
||||
max_document_tokens = (
|
||||
remaining_tokens
|
||||
- 512 # Needs to be more than any of the QA prompts
|
||||
- check_number_of_tokens(query_text)
|
||||
)
|
||||
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
return None
|
||||
|
||||
# This also handles creating the query event in postgres
|
||||
answer = get_search_answer(
|
||||
query_req=new_message_request,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
answer_generation_timeout=answer_generation_timeout,
|
||||
enable_reflexion=reflexion,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=True,
|
||||
)
|
||||
|
||||
answer = gather_stream_for_slack(packets)
|
||||
|
||||
if answer.error_msg:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
return answer
|
||||
if not answer.error_msg:
|
||||
return answer
|
||||
else:
|
||||
raise RuntimeError(answer.error_msg)
|
||||
|
||||
try:
|
||||
# By leaving time_cutoff and favor_recent as None, and setting enable_auto_detect_filters
|
||||
@@ -196,24 +239,26 @@ def handle_regular_answer(
|
||||
enable_auto_detect_filters=auto_detect_filters,
|
||||
)
|
||||
|
||||
# Always apply reranking settings if it exists, this is the non-streaming flow
|
||||
with get_session_with_tenant(tenant_id) as db_session:
|
||||
answer_request = prepare_chat_message_request(
|
||||
message_text=combined_message,
|
||||
user=user,
|
||||
persona_id=persona.id,
|
||||
# This is not used in the Slack flow, only in the answer API
|
||||
persona_override_config=None,
|
||||
prompt=prompt,
|
||||
message_ts_to_respond_to=message_ts_to_respond_to,
|
||||
retrieval_details=retrieval_details,
|
||||
rerank_settings=None, # Rerank customization supported in Slack flow
|
||||
db_session=db_session,
|
||||
saved_search_settings = get_current_search_settings(db_session)
|
||||
|
||||
# This includes throwing out answer via reflexion
|
||||
answer = _get_answer(
|
||||
DirectQARequest(
|
||||
messages=messages,
|
||||
multilingual_query_expansion=saved_search_settings.multilingual_expansion
|
||||
if saved_search_settings
|
||||
else None,
|
||||
prompt_id=prompt.id if prompt else None,
|
||||
persona_id=persona.id if persona is not None else 0,
|
||||
retrieval_options=retrieval_details,
|
||||
chain_of_thought=not disable_cot,
|
||||
rerank_settings=RerankingDetails.from_db_model(saved_search_settings)
|
||||
if saved_search_settings
|
||||
else None,
|
||||
)
|
||||
|
||||
answer = _get_slack_answer(
|
||||
new_message_request=answer_request, danswer_user=user
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Unable to process message - did not successfully answer "
|
||||
@@ -314,7 +359,7 @@ def handle_regular_answer(
|
||||
top_docs = retrieval_info.top_documents
|
||||
if not top_docs and not should_respond_even_with_no_docs:
|
||||
logger.error(
|
||||
f"Unable to answer question: '{combined_message}' - no documents found"
|
||||
f"Unable to answer question: '{answer.rephrase}' - no documents found"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
@@ -335,18 +380,18 @@ def handle_regular_answer(
|
||||
)
|
||||
return True
|
||||
|
||||
only_respond_if_citations = (
|
||||
only_respond_with_citations_or_quotes = (
|
||||
channel_conf
|
||||
and "well_answered_postfilter" in channel_conf.get("answer_filters", [])
|
||||
)
|
||||
|
||||
has_citations_or_quotes = bool(answer.citations or answer.quotes)
|
||||
if (
|
||||
only_respond_if_citations
|
||||
and not answer.citations
|
||||
only_respond_with_citations_or_quotes
|
||||
and not has_citations_or_quotes
|
||||
and not message_info.bypass_filters
|
||||
):
|
||||
logger.error(
|
||||
f"Unable to find citations to answer: '{answer.answer}' - not answering!"
|
||||
f"Unable to find citations or quotes to answer: '{answer.rephrase}' - not answering!"
|
||||
)
|
||||
# Optionally, respond in thread with the error message
|
||||
# Used primarily for debugging purposes
|
||||
@@ -364,8 +409,9 @@ def handle_regular_answer(
|
||||
tenant_id=tenant_id,
|
||||
message_info=message_info,
|
||||
answer=answer,
|
||||
persona=persona,
|
||||
channel_conf=channel_conf,
|
||||
use_citations=True, # No longer supporting quotes
|
||||
use_citations=use_citations,
|
||||
feedback_reminder_id=feedback_reminder_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,33 +1,8 @@
from slack_sdk import WebClient

from danswer.chat.models import ThreadMessage
from danswer.configs.constants import MessageType
from danswer.danswerbot.slack.utils import respond_in_thread


def slackify_message_thread(messages: list[ThreadMessage]) -> str:
    # Note: this does not handle extremely long threads; every message will be included.
    # With weaker LLMs, this could cause issues with exceeding the token limit.
    if not messages:
        return ""

    message_strs: list[str] = []
    for message in messages:
        if message.role == MessageType.USER:
            message_text = (
                f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
            )
        elif message.role == MessageType.ASSISTANT:
            message_text = f"AI said in Slack:\n{message.message}"
        else:
            message_text = (
                f"{message.role.value.upper()} said in Slack:\n{message.message}"
            )
        message_strs.append(message_text)

    return "\n\n".join(message_strs)


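A small usage illustration of the helper above; the ThreadMessage constructor fields are inferred from the attributes the function reads (message, sender, role), so treat the construction as approximate.

# Approximate usage; ThreadMessage fields are inferred from the function body.
from danswer.chat.models import ThreadMessage
from danswer.configs.constants import MessageType

thread = [
    ThreadMessage(message="How do I rotate an API key?", sender="Alice", role=MessageType.USER),
    ThreadMessage(message="Use the admin panel under API Keys.", sender=None, role=MessageType.ASSISTANT),
]
print(slackify_message_thread(thread))
# Alice said in Slack:
# How do I rotate an API key?
#
# AI said in Slack:
# Use the admin panel under API Keys.
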
def send_team_member_message(
|
||||
client: WebClient,
|
||||
channel: str,
|
||||
|
||||
@@ -19,8 +19,6 @@ from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
from slack_sdk.socket_mode.response import SocketModeResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.configs.app_configs import DEV_MODE
|
||||
from danswer.configs.app_configs import POD_NAME
|
||||
from danswer.configs.app_configs import POD_NAMESPACE
|
||||
from danswer.configs.constants import DanswerRedisLocks
|
||||
@@ -76,6 +74,7 @@ from danswer.db.slack_bot import fetch_slack_bots
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.redis.redis_pool import get_redis_client
|
||||
from danswer.server.manage.models import SlackBotTokens
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -251,7 +250,7 @@ class SlackbotHandler:
|
||||
nx=True,
|
||||
ex=TENANT_LOCK_EXPIRATION,
|
||||
)
|
||||
if not acquired and not DEV_MODE:
|
||||
if not acquired:
|
||||
logger.debug(f"Another pod holds the lock for tenant {tenant_id}")
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import ThreadMessage
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
|
||||
|
||||
class SlackMessageInfo(BaseModel):
|
||||
|
||||
@@ -30,13 +30,13 @@ from danswer.configs.danswerbot_configs import (
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.connectors.slack.utils import SlackTextCleaner
|
||||
from danswer.danswerbot.slack.constants import FeedbackVisibility
|
||||
from danswer.danswerbot.slack.models import ThreadMessage
|
||||
from danswer.db.engine import get_session_with_tenant
|
||||
from danswer.db.users import get_user_by_email
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.prompts.miscellaneous_prompts import SLACK_LANGUAGE_REPHRASE_PROMPT
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.telemetry import optional_telemetry
|
||||
|
||||
@@ -145,10 +145,16 @@ def get_chat_sessions_by_user(
|
||||
user_id: UUID | None,
|
||||
deleted: bool | None,
|
||||
db_session: Session,
|
||||
only_one_shot: bool = False,
|
||||
limit: int = 50,
|
||||
) -> list[ChatSession]:
|
||||
stmt = select(ChatSession).where(ChatSession.user_id == user_id)
|
||||
|
||||
if only_one_shot:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(True))
|
||||
else:
|
||||
stmt = stmt.where(ChatSession.one_shot.is_(False))
|
||||
|
||||
stmt = stmt.order_by(desc(ChatSession.time_created))
|
||||
|
||||
if deleted is not None:
|
||||
@@ -220,11 +226,12 @@ def delete_messages_and_files_from_chat_session(
|
||||
|
||||
def create_chat_session(
|
||||
db_session: Session,
|
||||
description: str | None,
|
||||
description: str,
|
||||
user_id: UUID | None,
|
||||
persona_id: int | None, # Can be none if temporary persona is used
|
||||
llm_override: LLMOverride | None = None,
|
||||
prompt_override: PromptOverride | None = None,
|
||||
one_shot: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
slack_thread_id: str | None = None,
|
||||
) -> ChatSession:
|
||||
@@ -234,6 +241,7 @@ def create_chat_session(
|
||||
description=description,
|
||||
llm_override=llm_override,
|
||||
prompt_override=prompt_override,
|
||||
one_shot=one_shot,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
slack_thread_id=slack_thread_id,
|
||||
)
|
||||
@@ -279,6 +287,8 @@ def duplicate_chat_session_for_user_from_slack(
|
||||
description="",
|
||||
llm_override=chat_session.llm_override,
|
||||
prompt_override=chat_session.prompt_override,
|
||||
# Chat sessions from Slack should put people in the chat UI, not the search
|
||||
one_shot=False,
|
||||
# Chat is in UI now so this is false
|
||||
danswerbot_flow=False,
|
||||
# Maybe we want this in the future to track if it was created from Slack
|
||||
|
||||
@@ -37,7 +37,6 @@ from danswer.configs.app_configs import POSTGRES_PORT
|
||||
from danswer.configs.app_configs import POSTGRES_USER
|
||||
from danswer.configs.app_configs import USER_AUTH_SECRET
|
||||
from danswer.configs.constants import POSTGRES_UNKNOWN_APP_NAME
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
@@ -427,9 +426,7 @@ def get_session() -> Generator[Session, None, None]:
|
||||
"""Generate a database session with the appropriate tenant schema set."""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
|
||||
raise BasicAuthenticationError(
|
||||
detail="User must authenticate",
|
||||
)
|
||||
raise HTTPException(status_code=401, detail="User must authenticate")
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
from enum import Enum as PyEnum
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import NotRequired
|
||||
@@ -963,8 +964,9 @@ class ChatSession(Base):
|
||||
persona_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("persona.id"), nullable=True
|
||||
)
|
||||
description: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
# This chat created by DanswerBot
|
||||
description: Mapped[str] = mapped_column(Text)
|
||||
# One-shot direct answering, currently the two types of chats are not mixed
|
||||
one_shot: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
danswerbot_flow: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# Only ever set to True if system is set to not hard-delete chats
|
||||
deleted: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
@@ -1486,6 +1488,11 @@ class ChannelConfig(TypedDict):
|
||||
show_continue_in_web_ui: NotRequired[bool] # defaults to False
|
||||
|
||||
|
||||
class SlackBotResponseType(str, PyEnum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
|
||||
|
||||
class SlackChannelConfig(Base):
|
||||
__tablename__ = "slack_channel_config"
|
||||
|
||||
@@ -1498,6 +1505,9 @@ class SlackChannelConfig(Base):
|
||||
channel_config: Mapped[ChannelConfig] = mapped_column(
|
||||
postgresql.JSONB(), nullable=False
|
||||
)
|
||||
response_type: Mapped[SlackBotResponseType] = mapped_column(
|
||||
Enum(SlackBotResponseType, native_enum=False), nullable=False
|
||||
)
|
||||
|
||||
enable_auto_filters: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
|
||||
@@ -415,6 +415,9 @@ def upsert_prompt(
|
||||
return prompt
|
||||
|
||||
|
||||
# NOTE: This operation cannot update persona configuration options that
|
||||
# are core to the persona, such as its display priority and
|
||||
# whether or not the assistant is a built-in / default assistant
|
||||
def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
@@ -446,12 +449,6 @@ def upsert_persona(
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
) -> Persona:
|
||||
"""
|
||||
NOTE: This operation cannot update persona configuration options that
|
||||
are core to the persona, such as its display priority and
|
||||
whether or not the assistant is a built-in / default assistant
|
||||
"""
|
||||
|
||||
if persona_id is not None:
|
||||
persona = db_session.query(Persona).filter_by(id=persona_id).first()
|
||||
else:
|
||||
@@ -489,8 +486,6 @@ def upsert_persona(
|
||||
validate_persona_tools(tools)
|
||||
|
||||
if persona:
|
||||
# Built-in personas can only be updated through YAML configuration.
|
||||
# This ensures that core system personas are not modified unintentionally.
|
||||
if persona.builtin_persona and not builtin_persona:
|
||||
raise ValueError("Cannot update builtin persona with non-builtin.")
|
||||
|
||||
@@ -499,9 +494,6 @@ def upsert_persona(
|
||||
db_session=db_session, persona_id=persona.id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
# The following update excludes `default`, `built-in`, and display priority.
|
||||
# Display priority is handled separately in the `display-priority` endpoint.
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
persona.name = name
|
||||
persona.description = description
|
||||
persona.num_chunks = num_chunks
|
||||
|
||||
@@ -10,6 +10,7 @@ from danswer.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Persona__DocumentSet
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_default_prompt
|
||||
@@ -82,6 +83,7 @@ def insert_slack_channel_config(
|
||||
slack_bot_id: int,
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
) -> SlackChannelConfig:
|
||||
@@ -113,6 +115,7 @@ def insert_slack_channel_config(
|
||||
slack_bot_id=slack_bot_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=response_type,
|
||||
standard_answer_categories=existing_standard_answer_categories,
|
||||
enable_auto_filters=enable_auto_filters,
|
||||
)
|
||||
@@ -127,6 +130,7 @@ def update_slack_channel_config(
|
||||
slack_channel_config_id: int,
|
||||
persona_id: int | None,
|
||||
channel_config: ChannelConfig,
|
||||
response_type: SlackBotResponseType,
|
||||
standard_answer_category_ids: list[int],
|
||||
enable_auto_filters: bool,
|
||||
) -> SlackChannelConfig:
|
||||
@@ -166,6 +170,7 @@ def update_slack_channel_config(
|
||||
# will encounter `violates foreign key constraint` errors
|
||||
slack_channel_config.persona_id = persona_id
|
||||
slack_channel_config.channel_config = channel_config
|
||||
slack_channel_config.response_type = response_type
|
||||
slack_channel_config.standard_answer_categories = list(
|
||||
existing_standard_answer_categories
|
||||
)
|
||||
|
||||
@@ -18,12 +18,18 @@ from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.answering.prompts.build import default_build_system_message
|
||||
from danswer.llm.answering.prompts.build import default_build_user_message
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
AnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
CitationResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
DummyAnswerResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.answer_response_handler import (
|
||||
QuotesResponseHandler,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
|
||||
from danswer.llm.interfaces import LLM
|
||||
@@ -208,23 +214,18 @@ class Answer:
|
||||
|
||||
search_result = SearchTool.get_search_result(current_llm_call) or []
|
||||
|
||||
# Quotes are no longer supported
|
||||
# answer_handler: AnswerResponseHandler
|
||||
# if self.answer_style_config.citation_config:
|
||||
# answer_handler = CitationResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
# )
|
||||
# elif self.answer_style_config.quotes_config:
|
||||
# answer_handler = QuotesResponseHandler(
|
||||
# context_docs=search_result,
|
||||
# )
|
||||
# else:
|
||||
# raise ValueError("No answer style config provided")
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
answer_handler: AnswerResponseHandler
|
||||
if self.answer_style_config.citation_config:
|
||||
answer_handler = CitationResponseHandler(
|
||||
context_docs=search_result,
|
||||
doc_id_to_rank_map=map_document_id_order(search_result),
|
||||
)
|
||||
elif self.answer_style_config.quotes_config:
|
||||
answer_handler = QuotesResponseHandler(
|
||||
context_docs=search_result,
|
||||
)
|
||||
else:
|
||||
raise ValueError("No answer style config provided")
|
||||
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
|
||||
@@ -8,6 +8,7 @@ from pydantic.v1 import BaseModel as BaseModel__v1
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.file_store.models import InMemoryChatFile
|
||||
@@ -29,6 +30,7 @@ if TYPE_CHECKING:
|
||||
ResponsePart = (
|
||||
DanswerAnswerPiece
|
||||
| CitationInfo
|
||||
| DanswerQuotes
|
||||
| ToolCallKickoff
|
||||
| ToolResponse
|
||||
| ToolCallFinalResult
|
||||
|
||||
@@ -9,6 +9,9 @@ from danswer.llm.answering.llm_response_handler import ResponsePart
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
CitationProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
QuotesProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
@@ -67,29 +70,28 @@ class CitationResponseHandler(AnswerResponseHandler):
|
||||
yield from self.citation_processor.process_token(content)
|
||||
|
||||
|
||||
# No longer in use, remove later
|
||||
# class QuotesResponseHandler(AnswerResponseHandler):
|
||||
# def __init__(
|
||||
# self,
|
||||
# context_docs: list[LlmDoc],
|
||||
# is_json_prompt: bool = True,
|
||||
# ):
|
||||
# self.quotes_processor = QuotesProcessor(
|
||||
# context_docs=context_docs,
|
||||
# is_json_prompt=is_json_prompt,
|
||||
# )
|
||||
class QuotesResponseHandler(AnswerResponseHandler):
|
||||
def __init__(
|
||||
self,
|
||||
context_docs: list[LlmDoc],
|
||||
is_json_prompt: bool = True,
|
||||
):
|
||||
self.quotes_processor = QuotesProcessor(
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=is_json_prompt,
|
||||
)
|
||||
|
||||
# def handle_response_part(
|
||||
# self,
|
||||
# response_item: BaseMessage | None,
|
||||
# previous_response_items: list[BaseMessage],
|
||||
# ) -> Generator[ResponsePart, None, None]:
|
||||
# if response_item is None:
|
||||
# yield from self.quotes_processor.process_token(None)
|
||||
# return
|
||||
def handle_response_part(
|
||||
self,
|
||||
response_item: BaseMessage | None,
|
||||
previous_response_items: list[BaseMessage],
|
||||
) -> Generator[ResponsePart, None, None]:
|
||||
if response_item is None:
|
||||
yield from self.quotes_processor.process_token(None)
|
||||
return
|
||||
|
||||
# content = (
|
||||
# response_item.content if isinstance(response_item.content, str) else ""
|
||||
# )
|
||||
content = (
|
||||
response_item.content if isinstance(response_item.content, str) else ""
|
||||
)
|
||||
|
||||
# yield from self.quotes_processor.process_token(content)
|
||||
yield from self.quotes_processor.process_token(content)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
# THIS IS NO LONGER IN USE
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
@@ -6,10 +5,11 @@ from json import JSONDecodeError
|
||||
from typing import Optional
|
||||
|
||||
import regex
|
||||
from pydantic import BaseModel
|
||||
|
||||
from danswer.chat.models import DanswerAnswer
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerQuote
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.configs.chat_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
@@ -26,20 +26,6 @@ logger = setup_logger()
|
||||
answer_pattern = re.compile(r'{\s*"answer"\s*:\s*"', re.IGNORECASE)
|
||||
|
||||
|
||||
class DanswerQuote(BaseModel):
|
||||
# This is during inference so everything is a string by this point
|
||||
quote: str
|
||||
document_id: str
|
||||
link: str | None
|
||||
source_type: str
|
||||
semantic_identifier: str
|
||||
blurb: str
|
||||
|
||||
|
||||
class DanswerQuotes(BaseModel):
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
def _extract_answer_quotes_freeform(
|
||||
answer_raw: str,
|
||||
) -> tuple[Optional[str], Optional[list[str]]]:
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
@@ -14,11 +13,8 @@ from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.override_models import LLMOverride
|
||||
from danswer.utils.headers import build_llm_extra_headers
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _build_extra_model_kwargs(provider: str) -> dict[str, Any]:
|
||||
"""Ollama requires us to specify the max context window.
|
||||
@@ -36,15 +32,11 @@ def get_main_llm_from_tuple(
|
||||
|
||||
|
||||
def get_llms_for_persona(
|
||||
persona: Persona | PersonaOverrideConfig | None,
|
||||
persona: Persona,
|
||||
llm_override: LLMOverride | None = None,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> tuple[LLM, LLM]:
|
||||
if persona is None:
|
||||
logger.warning("No persona provided, using default LLMs")
|
||||
return get_default_llms()
|
||||
|
||||
model_provider_override = llm_override.model_provider if llm_override else None
|
||||
model_version_override = llm_override.model_version if llm_override else None
|
||||
temperature_override = llm_override.temperature if llm_override else None
|
||||
@@ -79,7 +71,6 @@ def get_llms_for_persona(
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
custom_config=llm_provider.custom_config,
|
||||
temperature=temperature_override,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
@@ -137,13 +128,11 @@ def get_llm(
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
custom_config: dict[str, str] | None = None,
|
||||
temperature: float | None = None,
|
||||
temperature: float = GEN_AI_TEMPERATURE,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> LLM:
|
||||
if temperature is None:
|
||||
temperature = GEN_AI_TEMPERATURE
|
||||
return DefaultMultiLLM(
|
||||
model_provider=provider,
|
||||
model_name=model,
|
||||
|
||||
@@ -25,6 +25,7 @@ from danswer.auth.schemas import UserCreate
|
||||
from danswer.auth.schemas import UserRead
|
||||
from danswer.auth.schemas import UserUpdate
|
||||
from danswer.auth.users import auth_backend
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import create_danswer_oauth_router
|
||||
from danswer.auth.users import fastapi_users
|
||||
from danswer.configs.app_configs import APP_API_PREFIX
|
||||
@@ -91,7 +92,6 @@ from danswer.server.settings.api import basic_router as settings_router
|
||||
from danswer.server.token_rate_limits.api import (
|
||||
router as token_rate_limit_settings_router,
|
||||
)
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.setup import setup_danswer
|
||||
from danswer.setup import setup_multitenant_danswer
|
||||
from danswer.utils.logger import setup_logger
|
||||
@@ -206,7 +206,7 @@ def log_http_error(_: Request, exc: Exception) -> JSONResponse:
|
||||
|
||||
if isinstance(exc, BasicAuthenticationError):
|
||||
# For BasicAuthenticationError, just log a brief message without stack trace (almost always spam)
|
||||
logger.warning(f"Authentication failed: {str(exc)}")
|
||||
logger.error(f"Authentication failed: {str(exc)}")
|
||||
|
||||
elif status_code >= 400:
|
||||
error_msg = f"{str(exc)}\n"
|
||||
|
||||
456
backend/danswer/one_shot_answer/answer_question.py
Normal file
456
backend/danswer/one_shot_answer/answer_question.py
Normal file
@@ -0,0 +1,456 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.chat.chat_utils import reorganize_citations
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import DocumentRelevance
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import RelevanceAnalysis
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
|
||||
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from danswer.configs.chat_configs import QA_TIMEOUT
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.models import RerankMetricsContainer
|
||||
from danswer.context.search.models import RetrievalMetricsContainer
|
||||
from danswer.context.search.utils import chunks_or_sections_to_search_docs
|
||||
from danswer.context.search.utils import dedupe_documents
|
||||
from danswer.db.chat import create_chat_session
|
||||
from danswer.db.chat import create_db_search_doc
|
||||
from danswer.db.chat import create_new_chat_message
|
||||
from danswer.db.chat import get_or_create_root_message
|
||||
from danswer.db.chat import translate_db_message_to_chat_message_detail
|
||||
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
|
||||
from danswer.db.chat import update_search_docs_table_with_relevance
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import User
|
||||
from danswer.db.persona import get_prompt_by_id
|
||||
from danswer.llm.answering.answer import Answer
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import DocumentPruningConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.models import QuotesConfig
|
||||
from danswer.llm.factory import get_llms_for_persona
|
||||
from danswer.llm.factory import get_main_llm_from_tuple
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.one_shot_answer.models import OneShotQAResponse
|
||||
from danswer.one_shot_answer.models import QueryRephrase
|
||||
from danswer.one_shot_answer.qa_utils import combine_message_thread
|
||||
from danswer.one_shot_answer.qa_utils import slackify_message_thread
|
||||
from danswer.secondary_llm_flows.answer_validation import get_answer_validity
|
||||
from danswer.secondary_llm_flows.query_expansion import thread_based_query_rephrase
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTENT_ID
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from danswer.tools.tool_implementations.search.search_tool import (
|
||||
SECTION_RELEVANCE_LIST_ID,
|
||||
)
|
||||
from danswer.tools.tool_runner import ToolCallKickoff
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.long_term_log import LongTermLogger
|
||||
from danswer.utils.timing import log_generator_function_time
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
AnswerObjectIterator = Iterator[
|
||||
QueryRephrase
|
||||
| QADocsResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| DanswerAnswerPiece
|
||||
| DanswerQuotes
|
||||
| DanswerContexts
|
||||
| StreamingError
|
||||
| ChatMessageDetail
|
||||
| CitationInfo
|
||||
| ToolCallKickoff
|
||||
| DocumentRelevance
|
||||
]
|
||||
|
||||
|
||||
def stream_answer_objects(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
# These need to be passed in because in Web UI one shot flow,
|
||||
# we can have much more document as there is no history.
|
||||
# For Slack flow, we need to save more tokens for the thread context
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
db_session: Session,
|
||||
# Needed to translate persona num_chunks to tokens to the LLM
|
||||
default_num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> AnswerObjectIterator:
|
||||
"""Streams in order:
|
||||
1. [always] Retrieved documents, stops flow if nothing is found
|
||||
2. [conditional] LLM selected chunk indices if LLM chunk filtering is turned on
|
||||
3. [always] A set of streamed DanswerAnswerPiece and DanswerQuotes at the end
|
||||
or an error anywhere along the line if something fails
|
||||
4. [always] Details on the final AI response message that is created
|
||||
"""
|
||||
user_id = user.id if user is not None else None
|
||||
query_msg = query_req.messages[-1]
|
||||
history = query_req.messages[:-1]
|
||||
|
||||
chat_session = create_chat_session(
|
||||
db_session=db_session,
|
||||
description="", # One shot queries don't need naming as it's never displayed
|
||||
user_id=user_id,
|
||||
persona_id=query_req.persona_id,
|
||||
one_shot=True,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
)
|
||||
|
||||
# permanent "log" store, used primarily for debugging
|
||||
long_term_logger = LongTermLogger(
|
||||
metadata={"user_id": str(user_id), "chat_session_id": str(chat_session.id)}
|
||||
)
|
||||
|
||||
temporary_persona: Persona | None = None
|
||||
|
||||
if query_req.persona_config is not None:
|
||||
temporary_persona = fetch_ee_implementation_or_noop(
|
||||
"danswer.server.query_and_chat.utils", "create_temporary_persona", None
|
||||
)(db_session=db_session, persona_config=query_req.persona_config, user=user)
|
||||
|
||||
persona = temporary_persona if temporary_persona else chat_session.persona
|
||||
|
||||
try:
|
||||
llm, fast_llm = get_llms_for_persona(
|
||||
persona=persona, long_term_logger=long_term_logger
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.error(
|
||||
f"Failed to initialize LLMs for persona '{persona.name}': {str(e)}"
|
||||
)
|
||||
if "No LLM provider" in str(e):
|
||||
raise ValueError(
|
||||
"Please configure a Generative AI model to use this feature."
|
||||
) from e
|
||||
raise ValueError(
|
||||
"Failed to initialize the AI model. Please check your configuration and try again."
|
||||
) from e
|
||||
|
||||
llm_tokenizer = get_tokenizer(
|
||||
model_name=llm.config.model_name,
|
||||
provider_type=llm.config.model_provider,
|
||||
)
|
||||
|
||||
# Create a chat session which will just store the root message, the query, and the AI response
|
||||
root_message = get_or_create_root_message(
|
||||
chat_session_id=chat_session.id, db_session=db_session
|
||||
)
|
||||
|
||||
history_str = combine_message_thread(
|
||||
messages=history,
|
||||
max_tokens=max_history_tokens,
|
||||
llm_tokenizer=llm_tokenizer,
|
||||
)
|
||||
|
||||
rephrased_query = query_req.query_override or thread_based_query_rephrase(
|
||||
user_query=query_msg.message,
|
||||
history_str=history_str,
|
||||
)
|
||||
|
||||
# Given back ahead of the documents for latency reasons
|
||||
# In chat flow it's given back along with the documents
|
||||
yield QueryRephrase(rephrased_query=rephrased_query)
|
||||
|
||||
prompt = None
|
||||
if query_req.prompt_id is not None:
|
||||
# NOTE: let the user access any prompt as long as the Persona is shared
|
||||
# with them
|
||||
prompt = get_prompt_by_id(
|
||||
prompt_id=query_req.prompt_id, user=None, db_session=db_session
|
||||
)
|
||||
if prompt is None:
|
||||
if not persona.prompts:
|
||||
raise RuntimeError(
|
||||
"Persona does not have any prompts - this should never happen"
|
||||
)
|
||||
prompt = persona.prompts[0]
|
||||
|
||||
user_message_str = query_msg.message
|
||||
# For this endpoint, we only save one user message to the chat session
|
||||
# However, for slackbot, we want to include the history of the entire thread
|
||||
if danswerbot_flow:
|
||||
# Right now, we only support bringing over citations and search docs
|
||||
# from the last message in the thread, not the entire thread
|
||||
# in the future, we may want to retrieve the entire thread
|
||||
user_message_str = slackify_message_thread(query_req.messages)
|
||||
|
||||
# Create the first User query message
|
||||
new_user_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=root_message,
|
||||
prompt_id=query_req.prompt_id,
|
||||
message=user_message_str,
|
||||
token_count=len(llm_tokenizer.encode(user_message_str)),
|
||||
message_type=MessageType.USER,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
prompt_config = PromptConfig.from_model(prompt)
|
||||
document_pruning_config = DocumentPruningConfig(
|
||||
max_chunks=int(
|
||||
persona.num_chunks if persona.num_chunks is not None else default_num_chunks
|
||||
),
|
||||
max_tokens=max_document_tokens,
|
||||
)
|
||||
|
||||
answer_config = AnswerStyleConfig(
|
||||
citation_config=CitationConfig() if use_citations else None,
|
||||
quotes_config=QuotesConfig() if not use_citations else None,
|
||||
document_pruning_config=document_pruning_config,
|
||||
)
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.SKIP
|
||||
if DISABLE_LLM_DOC_RELEVANCE
|
||||
else query_req.evaluation_type
|
||||
),
|
||||
persona=persona,
|
||||
retrieval_options=query_req.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=document_pruning_config,
|
||||
answer_style_config=answer_config,
|
||||
bypass_acl=bypass_acl,
|
||||
chunks_above=query_req.chunks_above,
|
||||
chunks_below=query_req.chunks_below,
|
||||
full_doc=query_req.full_doc,
|
||||
)
|
||||
|
||||
answer = Answer(
|
||||
question=query_msg.message,
|
||||
answer_style_config=answer_config,
|
||||
prompt_config=PromptConfig.from_model(prompt),
|
||||
llm=get_main_llm_from_tuple(
|
||||
get_llms_for_persona(persona=persona, long_term_logger=long_term_logger)
|
||||
),
|
||||
single_message_history=history_str,
|
||||
tools=[search_tool] if search_tool else [],
|
||||
force_use_tool=(
|
||||
ForceUseTool(
|
||||
tool_name=search_tool.name,
|
||||
args={"query": rephrased_query},
|
||||
force_use=True,
|
||||
)
|
||||
),
|
||||
# for now, don't use tool calling for this flow, as we haven't
|
||||
# tested quotes with tool calling too much yet
|
||||
skip_explicit_tool_calling=True,
|
||||
return_contexts=query_req.return_contexts,
|
||||
skip_gen_ai_answer_generation=query_req.skip_gen_ai_answer_generation,
|
||||
)
|
||||
# won't be any FileChatDisplay responses since that tool is never passed in
|
||||
for packet in cast(AnswerObjectIterator, answer.processed_streamed_output):
|
||||
# for one-shot flow, don't currently do anything with these
|
||||
if isinstance(packet, ToolResponse):
|
||||
# (likely fine that it comes after the initial creation of the search docs)
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response_summary = cast(SearchResponseSummary, packet.response)
|
||||
|
||||
top_docs = chunks_or_sections_to_search_docs(
|
||||
search_response_summary.top_sections
|
||||
)
|
||||
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
deduped_docs = top_docs
|
||||
if query_req.retrieval_options.dedupe_docs:
|
||||
deduped_docs, dropped_inds = dedupe_documents(top_docs)
|
||||
|
||||
reference_db_search_docs = [
|
||||
create_db_search_doc(server_search_doc=doc, db_session=db_session)
|
||||
for doc in deduped_docs
|
||||
]
|
||||
|
||||
response_docs = [
|
||||
translate_db_search_doc_to_server_search_doc(db_search_doc)
|
||||
for db_search_doc in reference_db_search_docs
|
||||
]
|
||||
|
||||
initial_response = QADocsResponse(
|
||||
rephrased_query=rephrased_query,
|
||||
top_documents=response_docs,
|
||||
predicted_flow=search_response_summary.predicted_flow,
|
||||
predicted_search=search_response_summary.predicted_search,
|
||||
applied_source_filters=search_response_summary.final_filters.source_type,
|
||||
applied_time_cutoff=search_response_summary.final_filters.time_cutoff,
|
||||
recency_bias_multiplier=search_response_summary.recency_bias_multiplier,
|
||||
)
|
||||
|
||||
yield initial_response
|
||||
|
||||
elif packet.id == SEARCH_DOC_CONTENT_ID:
|
||||
yield packet.response
|
||||
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
document_based_response = {}
|
||||
|
||||
if packet.response is not None:
|
||||
for evaluation in packet.response:
|
||||
document_based_response[
|
||||
evaluation.document_id
|
||||
] = RelevanceAnalysis(
|
||||
relevant=evaluation.relevant, content=evaluation.content
|
||||
)
|
||||
|
||||
evaluation_response = DocumentRelevance(
|
||||
relevance_summaries=document_based_response
|
||||
)
|
||||
if reference_db_search_docs is not None:
|
||||
update_search_docs_table_with_relevance(
|
||||
db_session=db_session,
|
||||
reference_db_search_docs=reference_db_search_docs,
|
||||
relevance_summary=evaluation_response,
|
||||
)
|
||||
yield evaluation_response
|
||||
|
||||
else:
|
||||
yield packet
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
gen_ai_response_message = create_new_chat_message(
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=new_user_message,
|
||||
prompt_id=query_req.prompt_id,
|
||||
message=answer.llm_answer,
|
||||
token_count=len(llm_tokenizer.encode(answer.llm_answer)),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
error=None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
yield msg_detail_response
|
||||
|
||||
|
||||
@log_generator_function_time()
|
||||
def stream_search_answer(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
) -> Iterator[str]:
|
||||
with get_session_context_manager() as session:
|
||||
objects = stream_answer_objects(
|
||||
query_req=query_req,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=session,
|
||||
)
|
||||
for obj in objects:
|
||||
yield get_json_line(obj.model_dump())
|
||||
|
||||
|
||||
def get_search_answer(
|
||||
query_req: DirectQARequest,
|
||||
user: User | None,
|
||||
max_document_tokens: int | None,
|
||||
max_history_tokens: int | None,
|
||||
db_session: Session,
|
||||
answer_generation_timeout: int = QA_TIMEOUT,
|
||||
enable_reflexion: bool = False,
|
||||
bypass_acl: bool = False,
|
||||
use_citations: bool = False,
|
||||
danswerbot_flow: bool = False,
|
||||
retrieval_metrics_callback: (
|
||||
Callable[[RetrievalMetricsContainer], None] | None
|
||||
) = None,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> OneShotQAResponse:
|
||||
"""Collects the streamed one shot answer responses into a single object"""
|
||||
qa_response = OneShotQAResponse()
|
||||
|
||||
results = stream_answer_objects(
|
||||
query_req=query_req,
|
||||
user=user,
|
||||
max_document_tokens=max_document_tokens,
|
||||
max_history_tokens=max_history_tokens,
|
||||
db_session=db_session,
|
||||
bypass_acl=bypass_acl,
|
||||
use_citations=use_citations,
|
||||
danswerbot_flow=danswerbot_flow,
|
||||
timeout=answer_generation_timeout,
|
||||
retrieval_metrics_callback=retrieval_metrics_callback,
|
||||
rerank_metrics_callback=rerank_metrics_callback,
|
||||
)
|
||||
|
||||
answer = ""
|
||||
for packet in results:
|
||||
if isinstance(packet, QueryRephrase):
|
||||
qa_response.rephrase = packet.rephrased_query
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
qa_response.docs = packet
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
qa_response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, DanswerQuotes):
|
||||
qa_response.quotes = packet
|
||||
elif isinstance(packet, CitationInfo):
|
||||
if qa_response.citations:
|
||||
qa_response.citations.append(packet)
|
||||
else:
|
||||
qa_response.citations = [packet]
|
||||
elif isinstance(packet, DanswerContexts):
|
||||
qa_response.contexts = packet
|
||||
elif isinstance(packet, StreamingError):
|
||||
qa_response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
qa_response.chat_message_id = packet.message_id
|
||||
|
||||
if answer:
|
||||
qa_response.answer = answer
|
||||
|
||||
if enable_reflexion:
|
||||
# Because follow up messages are explicitly tagged, we don't need to verify the answer
|
||||
if len(query_req.messages) == 1:
|
||||
first_query = query_req.messages[0].message
|
||||
qa_response.answer_valid = get_answer_validity(first_query, answer)
|
||||
else:
|
||||
qa_response.answer_valid = True
|
||||
|
||||
if use_citations and qa_response.answer and qa_response.citations:
|
||||
# Reorganize citation nums to be in the same order as the answer
|
||||
qa_response.answer, qa_response.citations = reorganize_citations(
|
||||
qa_response.answer, qa_response.citations
|
||||
)
|
||||
|
||||
return qa_response
|
||||
114
backend/danswer/one_shot_answer/models.py
Normal file
114
backend/danswer/one_shot_answer/models.py
Normal file
@@ -0,0 +1,114 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import DanswerQuotes
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.enums import RecencyBiasSetting
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import ChunkContext
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
|
||||
|
||||
class QueryRephrase(BaseModel):
|
||||
rephrased_query: str
|
||||
|
||||
|
||||
class ThreadMessage(BaseModel):
|
||||
message: str
|
||||
sender: str | None = None
|
||||
role: MessageType = MessageType.USER
|
||||
|
||||
|
||||
class PromptConfig(BaseModel):
|
||||
name: str
|
||||
description: str = ""
|
||||
system_prompt: str
|
||||
task_prompt: str = ""
|
||||
include_citations: bool = True
|
||||
datetime_aware: bool = True
|
||||
|
||||
|
||||
class ToolConfig(BaseModel):
|
||||
id: int
|
||||
|
||||
|
||||
class PersonaConfig(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
num_chunks: float | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
|
||||
prompts: list[PromptConfig] = Field(default_factory=list)
|
||||
prompt_ids: list[int] = Field(default_factory=list)
|
||||
|
||||
document_set_ids: list[int] = Field(default_factory=list)
|
||||
tools: list[ToolConfig] = Field(default_factory=list)
|
||||
tool_ids: list[int] = Field(default_factory=list)
|
||||
custom_tools_openapi: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class DirectQARequest(ChunkContext):
|
||||
persona_config: PersonaConfig | None = None
|
||||
persona_id: int | None = None
|
||||
|
||||
messages: list[ThreadMessage]
|
||||
prompt_id: int | None = None
|
||||
multilingual_query_expansion: list[str] | None = None
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
|
||||
|
||||
chain_of_thought: bool = False
|
||||
return_contexts: bool = False
|
||||
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# can be used if the message sent to the LLM / query should not be the same
|
||||
# will also disable Thread-based Rewording if specified
|
||||
query_override: str | None = None
|
||||
|
||||
# If True, skips generative an AI response to the search query
|
||||
skip_gen_ai_answer_generation: bool = False
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_persona_fields(self) -> "DirectQARequest":
|
||||
if (self.persona_config is None) == (self.persona_id is None):
|
||||
raise ValueError("Exactly one of persona_config or persona_id must be set")
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_chain_of_thought_and_prompt_id(self) -> "DirectQARequest":
|
||||
if self.chain_of_thought and self.prompt_id is not None:
|
||||
raise ValueError(
|
||||
"If chain_of_thought is True, prompt_id must be None"
|
||||
"The chain of thought prompt is only for question "
|
||||
"answering and does not accept customizing."
|
||||
)
|
||||
|
||||
return self
|
||||
|
||||
|
||||
class OneShotQAResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
rephrase: str | None = None
|
||||
quotes: DanswerQuotes | None = None
|
||||
citations: list[CitationInfo] | None = None
|
||||
docs: QADocsResponse | None = None
|
||||
llm_selected_doc_indices: list[int] | None = None
|
||||
error_msg: str | None = None
|
||||
answer_valid: bool = True # Reflexion result, default True if Reflexion not run
|
||||
chat_message_id: int | None = None
|
||||
contexts: DanswerContexts | None = None
|
||||
81
backend/danswer/one_shot_answer/qa_utils.py
Normal file
81
backend/danswer/one_shot_answer/qa_utils.py
Normal file
@@ -0,0 +1,81 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.natural_language_processing.utils import BaseTokenizer
|
||||
from danswer.one_shot_answer.models import ThreadMessage
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
|
||||
"""Mock streaming by generating the passed in model output, character by character"""
|
||||
for token in model_out:
|
||||
yield token
|
||||
|
||||
|
||||
def combine_message_thread(
|
||||
messages: list[ThreadMessage],
|
||||
max_tokens: int | None,
|
||||
llm_tokenizer: BaseTokenizer,
|
||||
) -> str:
|
||||
"""Used to create a single combined message context from threads"""
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
total_token_count = 0
|
||||
|
||||
for message in reversed(messages):
|
||||
if message.role == MessageType.USER:
|
||||
role_str = message.role.value.upper()
|
||||
if message.sender:
|
||||
role_str += " " + message.sender
|
||||
else:
|
||||
# Since other messages might have the user identifying information
|
||||
# better to use Unknown for symmetry
|
||||
role_str += " Unknown"
|
||||
else:
|
||||
role_str = message.role.value.upper()
|
||||
|
||||
msg_str = f"{role_str}:\n{message.message}"
|
||||
message_token_count = len(llm_tokenizer.encode(msg_str))
|
||||
|
||||
if (
|
||||
max_tokens is not None
|
||||
and total_token_count + message_token_count > max_tokens
|
||||
):
|
||||
break
|
||||
|
||||
message_strs.insert(0, msg_str)
|
||||
total_token_count += message_token_count
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
|
||||
|
||||
def slackify_message(message: ThreadMessage) -> str:
|
||||
if message.role != MessageType.USER:
|
||||
return message.message
|
||||
|
||||
return f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
|
||||
|
||||
def slackify_message_thread(messages: list[ThreadMessage]) -> str:
|
||||
if not messages:
|
||||
return ""
|
||||
|
||||
message_strs: list[str] = []
|
||||
for message in messages:
|
||||
if message.role == MessageType.USER:
|
||||
message_text = (
|
||||
f"{message.sender or 'Unknown User'} said in Slack:\n{message.message}"
|
||||
)
|
||||
elif message.role == MessageType.ASSISTANT:
|
||||
message_text = f"DanswerBot said in Slack:\n{message.message}"
|
||||
else:
|
||||
message_text = (
|
||||
f"{message.role.value.upper()} said in Slack:\n{message.message}"
|
||||
)
|
||||
message_strs.append(message_text)
|
||||
|
||||
return "\n\n".join(message_strs)
|
||||
@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import (
|
||||
construct_document_select_for_connector_credential_pair_by_needs_sync,
|
||||
@@ -106,7 +105,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -12,7 +12,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from danswer.db.document import construct_document_select_for_connector_credential_pair
|
||||
from danswer.db.models import Document as DbDocument
|
||||
@@ -115,7 +114,7 @@ class RedisConnectorDelete:
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc.id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
|
||||
@@ -12,7 +12,6 @@ from danswer.access.models import DocExternalAccess
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
|
||||
class RedisConnectorPermissionSyncPayload(BaseModel):
|
||||
@@ -133,8 +132,6 @@ class RedisConnectorPermissionSync:
|
||||
lock: RedisLock | None,
|
||||
new_permissions: list[DocExternalAccess],
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
) -> int | None:
|
||||
last_lock_time = time.monotonic()
|
||||
async_results = []
|
||||
@@ -152,13 +149,11 @@ class RedisConnectorPermissionSync:
|
||||
self.redis.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK,
|
||||
"update_external_document_permissions_task",
|
||||
kwargs=dict(
|
||||
tenant_id=self.tenant_id,
|
||||
serialized_doc_external_access=doc_perm.to_dict(),
|
||||
source_string=source_string,
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
),
|
||||
queue=DanswerCeleryQueues.DOC_PERMISSIONS_UPSERT,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -10,7 +10,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
|
||||
|
||||
@@ -135,7 +134,7 @@ class RedisConnectorPrune:
|
||||
|
||||
# Priority on sync's triggered by new indexing should be medium
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK,
|
||||
"document_by_cc_pair_cleanup_task",
|
||||
kwargs=dict(
|
||||
document_id=doc_id,
|
||||
connector_id=cc_pair.connector_id,
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.db.document_set import construct_document_select_by_docset
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
|
||||
@@ -77,7 +76,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
|
||||
from danswer.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryQueues
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.redis.redis_object_helper import RedisObjectHelper
|
||||
from danswer.utils.variable_functionality import fetch_versioned_implementation
|
||||
from danswer.utils.variable_functionality import global_version
|
||||
@@ -90,7 +89,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
redis_client.sadd(self.taskset_key, custom_task_id)
|
||||
|
||||
result = celery_app.send_task(
|
||||
DanswerCeleryTask.VESPA_METADATA_SYNC_TASK,
|
||||
"vespa_metadata_sync_task",
|
||||
kwargs=dict(document_id=doc.id, tenant_id=tenant_id),
|
||||
queue=DanswerCeleryQueues.VESPA_METADATA_SYNC,
|
||||
task_id=custom_task_id,
|
||||
|
||||
@@ -20,7 +20,6 @@ from danswer.background.celery.celery_utils import get_deletion_attempt_snapshot
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.configs.app_configs import ENABLED_CONNECTOR_TYPES
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import FileOrigin
|
||||
from danswer.connectors.google_utils.google_auth import (
|
||||
@@ -868,7 +867,7 @@ def connector_run_once(
|
||||
|
||||
# run the beat task to pick up the triggers immediately
|
||||
primary_app.send_task(
|
||||
DanswerCeleryTask.CHECK_FOR_INDEXING,
|
||||
"check_for_indexing",
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.background.celery.versioned_apps.primary import app as primary_app
|
||||
from danswer.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
|
||||
from danswer.configs.constants import DanswerCeleryPriority
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
|
||||
from danswer.db.connector_credential_pair import get_connector_credential_pair
|
||||
@@ -200,7 +199,7 @@ def create_deletion_attempt_for_connector_id(
|
||||
|
||||
# run the beat task to pick up this deletion from the db immediately
|
||||
primary_app.send_task(
|
||||
DanswerCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
"check_for_connector_deletion_task",
|
||||
priority=DanswerCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -16,6 +15,7 @@ from danswer.danswerbot.slack.config import VALID_SLACK_FILTERS
|
||||
from danswer.db.models import AllowedAnswerFilters
|
||||
from danswer.db.models import ChannelConfig
|
||||
from danswer.db.models import SlackBot as SlackAppModel
|
||||
from danswer.db.models import SlackBotResponseType
|
||||
from danswer.db.models import SlackChannelConfig as SlackChannelConfigModel
|
||||
from danswer.db.models import User
|
||||
from danswer.server.features.persona.models import PersonaSnapshot
|
||||
@@ -148,12 +148,6 @@ class SlackBotTokens(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
|
||||
# TODO No longer in use, remove later
|
||||
class SlackBotResponseType(str, Enum):
|
||||
QUOTES = "quotes"
|
||||
CITATIONS = "citations"
|
||||
|
||||
|
||||
class SlackChannelConfigCreationRequest(BaseModel):
|
||||
slack_bot_id: int
|
||||
# currently, a persona is created for each Slack channel config
|
||||
@@ -203,6 +197,7 @@ class SlackChannelConfig(BaseModel):
|
||||
id: int
|
||||
persona: PersonaSnapshot | None
|
||||
channel_config: ChannelConfig
|
||||
response_type: SlackBotResponseType
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories: list[StandardAnswerCategory]
|
||||
enable_auto_filters: bool
|
||||
@@ -222,6 +217,7 @@ class SlackChannelConfig(BaseModel):
|
||||
else None
|
||||
),
|
||||
channel_config=slack_channel_config_model.channel_config,
|
||||
response_type=slack_channel_config_model.response_type,
|
||||
# XXX this is going away soon
|
||||
standard_answer_categories=[
|
||||
StandardAnswerCategory.from_model(standard_answer_category_model)
|
||||
|
||||
@@ -118,6 +118,7 @@ def create_slack_channel_config(
|
||||
slack_bot_id=slack_channel_config_creation_request.slack_bot_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_channel_config_creation_request.response_type,
|
||||
standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories,
|
||||
db_session=db_session,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
@@ -181,6 +182,7 @@ def patch_slack_channel_config(
|
||||
slack_channel_config_id=slack_channel_config_id,
|
||||
persona_id=persona_id,
|
||||
channel_config=channel_config,
|
||||
response_type=slack_channel_config_creation_request.response_type,
|
||||
standard_answer_category_ids=slack_channel_config_creation_request.standard_answer_categories,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
)
|
||||
|
||||
@@ -26,6 +26,7 @@ from danswer.auth.noauth_user import fetch_no_auth_user
|
||||
from danswer.auth.noauth_user import set_no_auth_user_preferences
|
||||
from danswer.auth.schemas import UserRole
|
||||
from danswer.auth.schemas import UserStatus
|
||||
from danswer.auth.users import BasicAuthenticationError
|
||||
from danswer.auth.users import current_admin_user
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_user
|
||||
@@ -59,7 +60,6 @@ from danswer.server.manage.models import UserRoleUpdateRequest
|
||||
from danswer.server.models import FullUserSnapshot
|
||||
from danswer.server.models import InvitedUserSnapshot
|
||||
from danswer.server.models import MinimalUserSnapshot
|
||||
from danswer.server.utils import BasicAuthenticationError
|
||||
from danswer.server.utils import send_user_email_invite
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
@@ -109,7 +109,6 @@ def process_run_in_background(
|
||||
prompt_id=chat_session.persona.prompts[0].id,
|
||||
search_doc_ids=None,
|
||||
retrieval_options=search_tool_retrieval_details, # Adjust as needed
|
||||
rerank_settings=None,
|
||||
query_override=None,
|
||||
regenerate=None,
|
||||
llm_override=None,
|
||||
|
||||
@@ -5,14 +5,12 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
from pydantic import model_validator
|
||||
|
||||
from danswer.chat.models import PersonaOverrideConfig
|
||||
from danswer.chat.models import RetrievalDocs
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.configs.constants import SearchFeedbackType
|
||||
from danswer.context.search.models import BaseFilters
|
||||
from danswer.context.search.models import ChunkContext
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.context.search.models import SearchDoc
|
||||
from danswer.context.search.models import Tag
|
||||
@@ -89,8 +87,6 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# If search_doc_ids provided, then retrieval options are unused
|
||||
search_doc_ids: list[int] | None
|
||||
retrieval_options: RetrievalDetails | None
|
||||
# Useable via the APIs but not recommended for most flows
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
# allows the caller to specify the exact search query they want to use
|
||||
# will disable Query Rewording if specified
|
||||
query_override: str | None = None
|
||||
@@ -106,10 +102,6 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# allow user to specify an alternate assistnat
|
||||
alternate_assistant_id: int | None = None
|
||||
|
||||
# This takes the priority over the prompt_override
|
||||
# This won't be a type that's passed in directly from the API
|
||||
persona_override_config: PersonaOverrideConfig | None = None
|
||||
|
||||
# used for seeded chats to kick off the generation of an AI answer
|
||||
use_existing_user_message: bool = False
|
||||
|
||||
@@ -153,7 +145,7 @@ class RenameChatSessionResponse(BaseModel):
|
||||
|
||||
class ChatSessionDetails(BaseModel):
|
||||
id: UUID
|
||||
name: str | None
|
||||
name: str
|
||||
persona_id: int | None = None
|
||||
time_created: str
|
||||
shared_status: ChatSessionSharedStatus
|
||||
@@ -206,14 +198,14 @@ class ChatMessageDetail(BaseModel):
|
||||
|
||||
class SearchSessionDetailResponse(BaseModel):
|
||||
search_session_id: UUID
|
||||
description: str | None
|
||||
description: str
|
||||
documents: list[SearchDoc]
|
||||
messages: list[ChatMessageDetail]
|
||||
|
||||
|
||||
class ChatSessionDetailResponse(BaseModel):
|
||||
chat_session_id: UUID
|
||||
description: str | None
|
||||
description: str
|
||||
persona_id: int | None = None
|
||||
persona_name: str | None
|
||||
messages: list[ChatMessageDetail]
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.auth.users import current_curator_or_admin_user
|
||||
from danswer.auth.users import current_limited_user
|
||||
from danswer.auth.users import current_user
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.configs.constants import MessageType
|
||||
@@ -28,6 +32,8 @@ from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.tag import find_tags
|
||||
from danswer.document_index.factory import get_default_document_index
|
||||
from danswer.document_index.vespa.index import VespaIndex
|
||||
from danswer.one_shot_answer.answer_question import stream_search_answer
|
||||
from danswer.one_shot_answer.models import DirectQARequest
|
||||
from danswer.server.query_and_chat.models import AdminSearchRequest
|
||||
from danswer.server.query_and_chat.models import AdminSearchResponse
|
||||
from danswer.server.query_and_chat.models import ChatSessionDetails
|
||||
@@ -35,6 +41,7 @@ from danswer.server.query_and_chat.models import ChatSessionsResponse
|
||||
from danswer.server.query_and_chat.models import SearchSessionDetailResponse
|
||||
from danswer.server.query_and_chat.models import SourceTag
|
||||
from danswer.server.query_and_chat.models import TagResponse
|
||||
from danswer.server.query_and_chat.token_limit import check_token_rate_limits
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -133,7 +140,7 @@ def get_user_search_sessions(
|
||||
|
||||
try:
|
||||
search_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id, deleted=False, db_session=db_session
|
||||
user_id=user_id, deleted=False, db_session=db_session, only_one_shot=True
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(
|
||||
@@ -222,3 +229,29 @@ def get_search_session(
|
||||
],
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@basic_router.post("/stream-answer-with-quote")
|
||||
def get_answer_with_quote(
|
||||
query_request: DirectQARequest,
|
||||
user: User = Depends(current_limited_user),
|
||||
_: None = Depends(check_token_rate_limits),
|
||||
) -> StreamingResponse:
|
||||
query = query_request.messages[0].message
|
||||
|
||||
logger.notice(f"Received query for one shot answer with quotes: {query}")
|
||||
|
||||
def stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
for packet in stream_search_answer(
|
||||
query_req=query_request,
|
||||
user=user,
|
||||
max_document_tokens=None,
|
||||
max_history_tokens=0,
|
||||
):
|
||||
yield json.dumps(packet) if isinstance(packet, dict) else packet
|
||||
except Exception as e:
|
||||
logger.exception("Error in search answer streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(stream_generator(), media_type="application/json")
|
||||
|
||||
@@ -6,9 +6,6 @@ from email.mime.text import MIMEText
|
||||
from textwrap import dedent
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
|
||||
from danswer.configs.app_configs import SMTP_PASS
|
||||
from danswer.configs.app_configs import SMTP_PORT
|
||||
from danswer.configs.app_configs import SMTP_SERVER
|
||||
@@ -17,11 +14,6 @@ from danswer.configs.app_configs import WEB_DOMAIN
|
||||
from danswer.db.models import User
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that converts datetime objects to ISO format strings."""
|
||||
|
||||
|
||||
@@ -25,9 +25,6 @@ class ToolCallSummary(BaseModel__v1):
|
||||
tool_call_request: AIMessage
|
||||
tool_call_result: ToolMessage
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
def tool_call_tokens(
|
||||
tool_call_summary: ToolCallSummary, llm_tokenizer: BaseTokenizer
|
||||
|
||||
@@ -13,7 +13,6 @@ from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.context.search.enums import LLMEvaluationType
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.db.llm import fetch_existing_llm_providers
|
||||
from danswer.db.models import Persona
|
||||
@@ -103,14 +102,11 @@ class SearchToolConfig(BaseModel):
|
||||
default_factory=DocumentPruningConfig
|
||||
)
|
||||
retrieval_options: RetrievalDetails = Field(default_factory=RetrievalDetails)
|
||||
rerank_settings: RerankingDetails | None = None
|
||||
selected_sections: list[InferenceSection] | None = None
|
||||
chunks_above: int = 0
|
||||
chunks_below: int = 0
|
||||
full_doc: bool = False
|
||||
latest_query_files: list[InMemoryChatFile] | None = None
|
||||
# Use with care, should only be used for DanswerBot in channels with multiple users
|
||||
bypass_acl: bool = False
|
||||
|
||||
|
||||
class InternetSearchToolConfig(BaseModel):
|
||||
@@ -174,8 +170,6 @@ def construct_tools(
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
rerank_settings=search_tool_config.rerank_settings,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
)
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
|
||||
|
||||
@@ -77,7 +77,6 @@ def llm_doc_from_internet_search_result(result: InternetSearchResult) -> LlmDoc:
|
||||
updated_at=datetime.now(),
|
||||
link=result.link,
|
||||
source_links={0: result.link},
|
||||
match_highlights=[],
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -19,7 +19,6 @@ from danswer.context.search.enums import QueryFlow
|
||||
from danswer.context.search.enums import SearchType
|
||||
from danswer.context.search.models import IndexFilters
|
||||
from danswer.context.search.models import InferenceSection
|
||||
from danswer.context.search.models import RerankingDetails
|
||||
from danswer.context.search.models import RetrievalDetails
|
||||
from danswer.context.search.models import SearchRequest
|
||||
from danswer.context.search.pipeline import SearchPipeline
|
||||
@@ -104,7 +103,6 @@ class SearchTool(Tool):
|
||||
chunks_below: int | None = None,
|
||||
full_doc: bool = False,
|
||||
bypass_acl: bool = False,
|
||||
rerank_settings: RerankingDetails | None = None,
|
||||
) -> None:
|
||||
self.user = user
|
||||
self.persona = persona
|
||||
@@ -120,9 +118,6 @@ class SearchTool(Tool):
|
||||
self.bypass_acl = bypass_acl
|
||||
self.db_session = db_session
|
||||
|
||||
# Only used via API
|
||||
self.rerank_settings = rerank_settings
|
||||
|
||||
self.chunks_above = (
|
||||
chunks_above
|
||||
if chunks_above is not None
|
||||
@@ -297,7 +292,6 @@ class SearchTool(Tool):
|
||||
self.retrieval_options.offset if self.retrieval_options else None
|
||||
),
|
||||
limit=self.retrieval_options.limit if self.retrieval_options else None,
|
||||
rerank_settings=self.rerank_settings,
|
||||
chunks_above=self.chunks_above,
|
||||
chunks_below=self.chunks_below,
|
||||
full_doc=self.full_doc,
|
||||
|
||||
@@ -33,12 +33,12 @@ def log_function_time(
|
||||
elapsed_time_str = f"{elapsed_time:.3f}"
|
||||
log_name = func_name or func.__name__
|
||||
args_str = f" args={args} kwargs={kwargs}" if include_args else ""
|
||||
f"{log_name}{args_str} took {elapsed_time_str} seconds"
|
||||
# if debug_only:
|
||||
# logger.debug(final_log)
|
||||
# else:
|
||||
# # These are generally more important logs so the level is a bit higher
|
||||
# logger.notice(final_log)
|
||||
final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
|
||||
if debug_only:
|
||||
logger.debug(final_log)
|
||||
else:
|
||||
# These are generally more important logs so the level is a bit higher
|
||||
logger.notice(final_log)
|
||||
|
||||
if not print_only:
|
||||
optional_telemetry(
|
||||
|
||||
@@ -4,17 +4,16 @@ from typing import Any
|
||||
from danswer.background.celery.tasks.beat_schedule import (
|
||||
tasks_to_schedule as base_tasks_to_schedule,
|
||||
)
|
||||
from danswer.configs.constants import DanswerCeleryTask
|
||||
|
||||
ee_tasks_to_schedule = [
|
||||
{
|
||||
"name": "autogenerate_usage_report",
|
||||
"task": DanswerCeleryTask.AUTOGENERATE_USAGE_REPORT_TASK,
|
||||
"task": "autogenerate_usage_report_task",
|
||||
"schedule": timedelta(days=30), # TODO: change this to config flag
|
||||
},
|
||||
{
|
||||
"name": "check-ttl-management",
|
||||
"task": DanswerCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
|
||||
"task": "check_ttl_management_task",
|
||||
"schedule": timedelta(hours=1),
|
||||
},
|
||||
]
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
from danswer.chat.models import AllCitations
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import DanswerContexts
|
||||
from danswer.chat.models import LLMRelevanceFilterResponse
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.process_message import ChatPacketStream
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.utils.timing import log_function_time
|
||||
from ee.danswer.server.query_and_chat.models import OneShotQAResponse
|
||||
|
||||
|
||||
@log_function_time()
|
||||
def gather_stream_for_answer_api(
|
||||
packets: ChatPacketStream,
|
||||
) -> OneShotQAResponse:
|
||||
response = OneShotQAResponse()
|
||||
|
||||
answer = ""
|
||||
for packet in packets:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
elif isinstance(packet, QADocsResponse):
|
||||
response.docs = packet
|
||||
# Extraneous, provided for backwards compatibility
|
||||
response.rephrase = packet.rephrased_query
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
response.chat_message_id = packet.message_id
|
||||
elif isinstance(packet, LLMRelevanceFilterResponse):
|
||||
response.llm_selected_doc_indices = packet.llm_selected_doc_indices
|
||||
elif isinstance(packet, AllCitations):
|
||||
response.citations = packet.citations
|
||||
elif isinstance(packet, DanswerContexts):
|
||||
response.contexts = packet
|
||||
|
||||
if answer:
|
||||
response.answer = answer
|
||||
|
||||
return response
|
||||
@@ -155,6 +155,7 @@ def _handle_standard_answers(
|
||||
else 0,
|
||||
danswerbot_flow=True,
|
||||
slack_thread_id=slack_thread_id,
|
||||
one_shot=True,
|
||||
)
|
||||
|
||||
root_message = get_or_create_root_message(
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user