forked from github/onyx

Compare commits


18 Commits

Author SHA1 Message Date
pablodanswer 1285b2f4d4 update for typing 2024-09-16 12:39:58 -07:00
pablodanswer 842628771b minor robustification for search 2024-09-16 12:07:52 -07:00
pablodanswer 7a9d5bd92e minor updates 2024-09-16 11:45:35 -07:00
pablodanswer 4f3b513ccb minor update 2024-09-16 11:44:39 -07:00
pablodanswer cd454dd780 update clarity 2024-09-16 11:37:24 -07:00
pablodanswer 9140ee99cb asdf 2024-09-16 11:26:57 -07:00
pablodanswer a64f27c895 functional 2024-09-16 11:26:57 -07:00
pablodanswer fdf5611a35 add back frozen message map:wq 2024-09-16 11:26:57 -07:00
pablodanswer c4f483d100 update port for integration testing 2024-09-16 11:26:57 -07:00
pablodanswer fc28c6b9e1 fix stubborn typing issue 2024-09-16 11:26:57 -07:00
pablodanswer 33e25dbd8b clean up logs / build issues 2024-09-16 11:26:57 -07:00
pablodanswer 659e8cb69e validated + build-ready 2024-09-16 11:26:57 -07:00
pablodanswer 681175e9c3 add edits and so on 2024-09-16 11:26:57 -07:00
pablodanswer de18ec7ea4 functional ux standing till 2024-09-16 11:26:57 -07:00
pablodanswer 9edbb0806d add back image citations 2024-09-16 11:26:57 -07:00
pablodanswer 63d10e7482 functional search and chat once again! 2024-09-16 11:26:57 -07:00
pablodanswer ff6a15b5af squash 2024-09-16 11:26:57 -07:00
pablodanswer 49397e8a86 add sequential tool calls 2024-09-16 11:26:57 -07:00
25 changed files with 1365 additions and 653 deletions

View File

@@ -0,0 +1,65 @@
"""single tool call per message
Revision ID: 4e8e7ae58189
Revises: 5c7fdadae813
Create Date: 2024-09-09 10:07:58.008838
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4e8e7ae58189"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the new column
op.add_column(
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_chat_message_tool_call",
"chat_message",
"tool_call",
["tool_call_id"],
["id"],
)
# Migrate existing data
op.execute(
"UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
)
# Drop the old relationship
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "message_id")
# Add a unique constraint to ensure one-to-one relationship
op.create_unique_constraint(
"uq_chat_message_tool_call_id", "chat_message", ["tool_call_id"]
)
def downgrade() -> None:
# Add back the old column
op.add_column(
"tool_call",
sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
)
op.create_foreign_key(
"tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
)
# Migrate data back
op.execute(
"UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
)
# Drop the new column
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "tool_call_id")
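
With the schema flipped to a single nullable chat_message.tool_call_id plus a unique constraint, each message now owns at most one tool call. A minimal sketch of the resulting access pattern after upgrading, assuming a SQLAlchemy Session and the ChatMessage model from the models.py hunk further down; illustrative only, not part of the commit:

from sqlalchemy.orm import Session

def get_tool_call_for_message(db_session: Session, message_id: int):
    # One-to-one lookup through the new FK; returns None if the
    # message never invoked a tool.
    chat_message = db_session.get(ChatMessage, message_id)
    if chat_message is None:
        return None
    return chat_message.tool_call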

View File

@@ -48,6 +48,8 @@ class QADocsResponse(RetrievalDocs):
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
FINISHED = "finished"
NEW_RESPONSE = "new_response"
class StreamStopInfo(BaseModel):
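
The two added stop reasons drive the sequential tool-call flow: FINISHED marks a normal end of stream, while NEW_RESPONSE marks the boundary between consecutive assistant responses inside one stream. A minimal consumer sketch, assuming an Answer instance as used below; handle is a hypothetical callback, not from this commit:

for packet in answer.processed_streamed_output:
    if isinstance(packet, StreamStopInfo):
        if packet.stop_reason == StreamStopReason.NEW_RESPONSE:
            # boundary between consecutive assistant responses; keep reading
            continue
        break  # CONTEXT_LENGTH, CANCELLED, or FINISHED end the stream
    handle(packet)  # hypothetical sink for answer pieces and tool packets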

View File

@@ -18,6 +18,8 @@ from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
@@ -617,6 +619,11 @@ def stream_chat_message_objects(
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
@@ -662,86 +669,168 @@ def stream_chat_message_objects(
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
yielded_message_id_info = True
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if isinstance(packet, StreamStopInfo):
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
break
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
db_citations = None
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
)
if reference_db_search_docs:
db_citations = _translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
# Saving Gen AI answer and responding with message info
if tool_result is None:
tool_call = None
else:
tool_call = ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=cast(
QADocsResponse, qa_docs_response
).rephrased_query
if qa_docs_response is not None
else None,
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=cast(MessageSpecificCitations, db_citations).citation_map
if db_citations is not None
else None,
error=None,
tool_call=tool_call,
)
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message.id
if user_message is not None
else gen_ai_response_message.id,
message_type=MessageType.ASSISTANT,
)
yielded_message_id_info = False
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message,
prompt_id=prompt_id,
overridden_model=overridden_model,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
db_session=db_session,
commit=False,
)
reference_db_search_docs = None
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
if not yielded_message_id_info:
yield MessageResponseIDInfo(
user_message_id=gen_ai_response_message.id,
reserved_assistant_message_id=reserved_message_id,
)
yielded_message_id_info = True
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(
CustomToolCallSummary, packet.response
)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
error_msg = str(e)
@@ -767,11 +856,8 @@ def stream_chat_message_objects(
)
yield AllCitations(citations=answer.citations)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
if answer.llm_answer == "":
return
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
@@ -786,16 +872,14 @@ def stream_chat_message_objects(
if message_specific_citations
else None,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
tool_call=ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
if tool_result
else [],
else None,
)
logger.debug("Committing messages")

View File

@@ -178,8 +178,14 @@ def delete_search_doc_message_relationship(
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
db_session.execute(stmt)
chat_message = (
db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
)
if chat_message and chat_message.tool_call_id:
stmt = delete(ToolCall).where(ToolCall.id == chat_message.tool_call_id)
db_session.execute(stmt)
chat_message.tool_call_id = None
db_session.commit()
@@ -388,7 +394,7 @@ def get_chat_messages_by_session(
)
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
stmt = stmt.options(joinedload(ChatMessage.tool_call))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -474,7 +480,7 @@ def create_new_chat_message(
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_calls: list[ToolCall] | None = None,
tool_call: ToolCall | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
@@ -494,7 +500,7 @@ def create_new_chat_message(
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.tool_call = tool_call if tool_call else None
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
@@ -513,7 +519,7 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
tool_call=tool_call if tool_call else None,
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
@@ -747,14 +753,13 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)

View File

@@ -854,10 +854,8 @@ class ToolCall(Base):
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
message: Mapped["ChatMessage"] = relationship(
"ChatMessage", back_populates="tool_calls"
"ChatMessage", back_populates="tool_call"
)
@@ -984,9 +982,14 @@ class ChatMessage(Base):
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
tool_call_id: Mapped[int | None] = mapped_column(
ForeignKey("tool_call.id"), nullable=True
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_call: Mapped["ToolCall"] = relationship(
"ToolCall", back_populates="message", foreign_keys=[tool_call_id]
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",

View File

@@ -16,6 +16,7 @@ from danswer.chat.models import LlmDoc
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.constants import MessageType
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
@@ -68,6 +69,7 @@ from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
from shared_configs.configs import MAX_TOOL_CALLS
logger = setup_logger()
@@ -161,6 +163,10 @@ class Answer:
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
self.final_context_docs: list = []
self.current_streamed_output: list = []
self.processing_stream: list = []
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
) -> None:
@@ -196,128 +202,213 @@ class Answer:
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_calls = 0
initiated = False
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args
tool_call_chunk = AIMessageChunk(
content="",
)
tool_call_chunk.tool_calls = [
{
"name": self.force_use_tool.tool_name,
"args": self.force_use_tool.args,
"id": str(uuid4()),
}
]
else:
# if tool calling is supported, first try the raw message
# to see if we don't need to use any tools
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
final_tool_definitions = [
tool.tool_definition()
for tool in filter_tools_for_force_tool_use(
self.tools, self.force_use_tool
)
]
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
if initiated:
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
initiated = True
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool.force_use else None,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
if tool_call_chunk is None:
tool_call_chunk = message
else:
tool_call_chunk += message # type: ignore
else:
if message.content:
if self.is_cancelled:
return
yield cast(str, message.content)
if (
message.additional_kwargs.get("usage_metadata", {}).get("stop")
== "length"
):
yield StreamStopInfo(
stop_reason=StreamStopReason.CONTEXT_LENGTH
)
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
if not tool_call_chunk:
return # no tool call needed
tool_call_chunk: AIMessageChunk | None = None
# if we have a tool call, we need to call the tool
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
yield from tool_runner.tool_responses()
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
tool_call_result=build_tool_message(
tool_call_request, tool_runner.tool_message_content()
),
)
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
]
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
tool_call_chunk = AIMessageChunk(content="")
tool_call_chunk.tool_calls = [
{
"name": self.force_use_tool.tool_name,
"args": self.force_use_tool.args,
"id": str(uuid4()),
}
]
else:
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question, img_urls=img_urls
)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question,
self.prompt_config,
self.latest_query_files,
tool_calls,
)
)
yield tool_runner.tool_final_result()
prompt = prompt_builder.build()
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
final_tool_definitions = [
tool.tool_definition()
for tool in filter_tools_for_force_tool_use(
self.tools, self.force_use_tool
)
]
yield from self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
existing_message = ""
return
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool.force_use else None,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
if tool_call_chunk is None:
tool_call_chunk = message
else:
if len(existing_message) > 0:
yield StreamStopInfo(
stop_reason=StreamStopReason.NEW_RESPONSE
)
existing_message = ""
tool_call_chunk += message # type: ignore
else:
if message.content:
if self.is_cancelled or tool_calls > 0:
return
existing_message += cast(str, message.content)
yield cast(str, message.content)
if (
message.additional_kwargs.get("usage_metadata", {}).get(
"stop"
)
== "length"
):
yield StreamStopInfo(
stop_reason=StreamStopReason.CONTEXT_LENGTH
)
if not tool_call_chunk:
logger.info("Skipped tool call but generated message")
return
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
tool_calls += 1
known_tools_by_name = [
tool
for tool in self.tools
if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
yield tool_kickoff
yield from tool_runner.tool_responses()
tool_responses = []
for response in tool_runner.tool_responses():
tool_responses.append(response)
yield response
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
tool_call_result=build_tool_message(
tool_call_request, tool_runner.tool_message_content()
),
)
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question, img_urls=img_urls
)
)
yield tool_runner.tool_final_result()
# Update message history with tool call and response
self.message_history.append(
PreviousMessage(
message=self.question,
message_type=MessageType.USER,
token_count=len(self.llm_tokenizer.encode(self.question)),
tool_call=None,
files=[],
)
)
self.message_history.append(
PreviousMessage(
message=str(tool_call_request),
message_type=MessageType.ASSISTANT,
token_count=len(
self.llm_tokenizer.encode(str(tool_call_request))
),
tool_call=None,
files=[],
)
)
self.message_history.append(
PreviousMessage(
message="\n".join(str(response) for response in tool_responses),
message_type=MessageType.SYSTEM,
token_count=len(
self.llm_tokenizer.encode(
"\n".join(str(response) for response in tool_responses)
)
),
tool_call=None,
files=[],
)
)
# Generate response based on updated message history
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
response_content = ""
for content in self._process_llm_stream(
prompt=prompt,
tools=None
# tools=[tool.tool_definition() for tool in self.tools],
):
if isinstance(content, str):
response_content += content
yield content
# Update message history with LLM response
self.message_history.append(
PreviousMessage(
message=response_content,
message_type=MessageType.ASSISTANT,
token_count=len(self.llm_tokenizer.encode(response_content)),
tool_call=None,
files=[],
)
)
# This method processes the LLM stream and yields the content or stop information
def _process_llm_stream(
@@ -346,139 +437,197 @@ class Answer:
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
tool_calls = 0
initiated = False
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
if initiated:
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
iter(
[
tool
for tool in self.tools
if tool.name == self.force_use_tool.tool_name
]
),
None,
)
if not tool:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
initiated = True
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,
llm=self.llm,
force_run=True,
)
)
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
iter(
[
tool
for tool in self.tools
if tool.name == self.force_use_tool.tool_name
]
),
None,
)
if not tool:
raise RuntimeError(
f"Tool '{self.force_use_tool.tool_name}' not found"
)
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,
llm=self.llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
chosen_tool_and_args = (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
query=self.question,
history=self.message_history,
llm=self.llm,
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
chosen_tool_and_args = (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
query=self.question,
history=self.message_history,
llm=self.llm,
)
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
return
tool_calls += 1
tool, tool_args = chosen_tool_and_args
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
tool_responses = []
for response in tool_runner.tool_responses():
tool_responses.append(response)
yield response
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_documents = cast(list[LlmDoc], response.response)
yield response
if final_context_documents is None:
raise RuntimeError(
f"{tool.name} did not return final context documents"
)
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name == ImageGenerationTool._NAME:
img_urls = []
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], response.response
)
img_urls = [img.url for img in img_generation_response]
yield response
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
img_urls=img_urls,
)
)
else:
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name,
*tool_runner.tool_responses(),
)
)
)
final_result = tool_runner.tool_final_result()
yield final_result
# Update message history
self.message_history.extend(
[
PreviousMessage(
message=str(self.question),
message_type=MessageType.USER,
token_count=len(self.llm_tokenizer.encode(str(self.question))),
tool_call=None,
files=[],
),
PreviousMessage(
message=f"Tool used: {tool.name}",
message_type=MessageType.ASSISTANT,
token_count=len(
self.llm_tokenizer.encode(f"Tool used: {tool.name}")
),
tool_call=None,
files=[],
),
PreviousMessage(
message=str(final_result),
message_type=MessageType.SYSTEM,
token_count=len(self.llm_tokenizer.encode(str(final_result))),
tool_call=None,
files=[],
),
]
)
# Generate response based on updated message history
prompt = prompt_builder.build()
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
return
tool, tool_args = chosen_tool_and_args
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
response_content = ""
for content in self._process_llm_stream(prompt=prompt, tools=None):
if isinstance(content, str):
response_content += content
yield content
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_documents = cast(list[LlmDoc], response.response)
yield response
if final_context_documents is None:
raise RuntimeError(
f"{tool.name} did not return final context documents"
)
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name == ImageGenerationTool._NAME:
img_urls = []
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], response.response
)
img_urls = [img.url for img in img_generation_response]
yield response
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
img_urls=img_urls,
)
)
# Update message history with LLM response
self.message_history.append(
PreviousMessage(
message=response_content,
message_type=MessageType.ASSISTANT,
token_count=len(self.llm_tokenizer.encode(response_content)),
tool_call=None,
files=[],
)
)
else:
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name,
*tool_runner.tool_responses(),
)
)
)
final = tool_runner.tool_final_result()
yield final
prompt = prompt_builder.build()
yield from self._process_llm_stream(prompt=prompt, tools=None)
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -495,6 +644,8 @@ class Answer:
else self._raw_output_for_non_explicit_tool_calling_llms()
)
self.processing_stream = []
def _process_stream(
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
) -> AnswerStream:
@@ -535,56 +686,70 @@ class Answer:
yield message
else:
# assumes all tool responses will come first, then the final answer
break
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
if not self.skip_gen_ai_answer_generation:
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
stream_stop_info = None
new_kickoff = None
stream_stop_info = None
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
nonlocal new_kickoff
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
yield cast(str, message)
for item in stream:
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
yield cast(str, item)
yield cast(str, message)
for item in stream:
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
if isinstance(item, ToolCallKickoff):
new_kickoff = item
return
else:
yield cast(str, item)
yield from process_answer_stream_fn(_stream())
yield from process_answer_stream_fn(_stream())
if stream_stop_info:
yield stream_stop_info
if stream_stop_info:
yield stream_stop_info
# handle new tool call (continuation of message)
if new_kickoff:
yield new_kickoff
processed_stream = []
for processed_packet in _process_stream(output_generator):
processed_stream.append(processed_packet)
yield processed_packet
if (
isinstance(processed_packet, StreamStopInfo)
and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
):
self.current_streamed_output = self.processing_stream
self.processing_stream = []
self._processed_stream = processed_stream
self.processing_stream.append(processed_packet)
yield processed_packet
self.current_streamed_output = self.processing_stream
self._processed_stream = self.processing_stream
@property
def llm_answer(self) -> str:
answer = ""
for packet in self.processed_streamed_output:
if not self._processed_stream and not self.current_streamed_output:
return ""
for packet in self.current_streamed_output or self._processed_stream or []:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
return answer
@property
def citations(self) -> list[CitationInfo]:
citations: list[CitationInfo] = []
for packet in self.processed_streamed_output:
for packet in self.current_streamed_output:
if isinstance(packet, CitationInfo):
citations.append(packet)

View File

@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_calls: list[ToolCallFinalResult]
tool_call: ToolCallFinalResult | None
@classmethod
def from_chat_message(
@@ -51,14 +51,13 @@ class PreviousMessage(BaseModel):
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
)
def to_langchain_msg(self) -> BaseMessage:

View File

@@ -36,7 +36,10 @@ def default_build_system_message(
def default_build_user_message(
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
user_query: str,
prompt_config: PromptConfig,
files: list[InMemoryChatFile] = [],
previous_tool_calls: int = 0,
) -> HumanMessage:
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
@@ -45,6 +48,12 @@ def default_build_user_message(
if prompt_config.task_prompt
else user_query
)
if previous_tool_calls > 0:
user_prompt = (
f"You have already generated the above so do not call a tool if not necessary. "
f"Remember the query is: `{user_prompt}`"
)
user_prompt = user_prompt.strip()
user_msg = HumanMessage(
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
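
The new previous_tool_calls counter lets callers tell the prompt builder that earlier tool rounds already produced output, nudging the LLM away from redundant tool calls. An illustrative call site, assuming a prompt_config from the surrounding code; the query string is hypothetical:

user_msg = default_build_user_message(
    user_query="Summarize the attached report",  # hypothetical query
    prompt_config=prompt_config,
    files=[],
    previous_tool_calls=1,  # > 0 after the first tool round
)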

View File

@@ -112,7 +112,7 @@ def translate_danswer_msg_to_langchain(
content = build_content_with_imgs(msg.message, files)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
return SystemMessage(content=content)
if msg.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
if msg.message_type == MessageType.USER:

View File

@@ -281,14 +281,17 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
def is_disconnected_sync() -> bool:
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
try:
return not future.result(timeout=0.01)
result = not future.result(timeout=0.01)
return result
except asyncio.TimeoutError:
logger.error("Asyncio timed out")
logger.error("Asyncio timed out while checking client connection")
return True
except asyncio.CancelledError:
return True
except Exception as e:
error_msg = str(e)
logger.critical(
f"An unexpected error occured with the disconnect check coroutine: {error_msg}"
f"An unexpected error occurred with the disconnect check coroutine: {error_msg}"
)
return True

View File

@@ -178,7 +178,7 @@ class ChatMessageDetail(BaseModel):
chat_session_id: int | None = None
citations: dict[int, int] | None = None
files: list[FileDescriptor]
tool_calls: list[ToolCallFinalResult]
tool_call: ToolCallFinalResult | None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore

View File

@@ -4,7 +4,7 @@ from danswer.llm.utils import build_content_with_imgs
IMG_GENERATION_SUMMARY_PROMPT = """
You have just created the attached images in response to the following query: "{query}".
You have just created the most recent attached images in response to the following query: "{query}".
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
"""

View File

@@ -21,6 +21,8 @@ CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
INTENT_MODEL_TAG = "v1.0.3"
# Tool call configs
MAX_TOOL_CALLS = 3
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512

View File

@@ -0,0 +1,44 @@
// This module handles AI message sequences - consecutive AI messages that are streamed
// separately but represent a single logical message. These utilities are used for
// processing and displaying such sequences in the chat interface.
import { Message } from "@/app/chat/interfaces";
import { DanswerDocument } from "@/lib/search/interfaces";
// Retrieves the consecutive AI messages at the end of the message history.
// This is useful for combining or processing the latest AI response sequence.
export function getConsecutiveAIMessagesAtEnd(
messageHistory: Message[]
): Message[] {
const aiMessages = [];
for (let i = messageHistory.length - 1; i >= 0; i--) {
if (messageHistory[i]?.type === "assistant") {
aiMessages.unshift(messageHistory[i]);
} else {
break;
}
}
return aiMessages;
}
// Extracts unique documents from a sequence of AI messages.
// This is used to compile a comprehensive list of referenced documents
// across multiple parts of an AI response.
export function getUniqueDocumentsFromAIMessages(
messages: Message[]
): DanswerDocument[] {
const uniqueDocumentsMap = new Map<string, DanswerDocument>();
messages.forEach((message) => {
if (message.documents) {
message.documents.forEach((doc) => {
const uniqueKey = `${doc.document_id}-${doc.chunk_ind}`;
if (!uniqueDocumentsMap.has(uniqueKey)) {
uniqueDocumentsMap.set(uniqueKey, doc);
}
});
}
});
return Array.from(uniqueDocumentsMap.values());
}
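
An illustrative use of these helpers, assuming a messageHistory array like the one maintained in ChatPage: combine the trailing assistant sequence into one display message and gather its unique citations. A sketch only, not part of the commit:

// Collect the trailing run of assistant messages and merge them for display.
const aiSequence = getConsecutiveAIMessagesAtEnd(messageHistory);
const combinedText = aiSequence.map((m) => m.message).join("");
const citedDocs = getUniqueDocumentsFromAIMessages(aiSequence);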

View File

@@ -101,6 +101,8 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import { SEARCH_TOOL_NAME } from "./tools/constants";
import { useUser } from "@/components/user/UserProvider";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
import { Button } from "@tremor/react";
import dynamic from "next/dynamic";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -133,7 +135,6 @@ export function ChatPage({
} = useChatContext();
const [showApiKeyModal, setShowApiKeyModal] = useState(true);
const { user, refreshUser, isLoadingUser } = useUser();
// chat session
@@ -248,13 +249,13 @@ export function ChatPage({
if (
lastMessage &&
lastMessage.type === "assistant" &&
lastMessage.toolCalls[0] &&
lastMessage.toolCalls[0].tool_result === undefined
lastMessage.toolCall &&
lastMessage.toolCall.tool_result === undefined
) {
const newCompleteMessageMap = new Map(
currentMessageMap(completeMessageDetail)
);
const updatedMessage = { ...lastMessage, toolCalls: [] };
const updatedMessage = { ...lastMessage, toolCall: null };
newCompleteMessageMap.set(lastMessage.messageId, updatedMessage);
updateCompleteMessageDetail(currentSession, newCompleteMessageMap);
}
@@ -483,7 +484,7 @@ export function ChatPage({
message: "",
type: "system",
files: [],
toolCalls: [],
toolCall: null,
parentMessageId: null,
childrenMessageIds: [firstMessageId],
latestChildMessageId: firstMessageId,
@@ -510,6 +511,7 @@ export function ChatPage({
}
newCompleteMessageMap.set(message.messageId, message);
});
// if specified, make these new messages the latest of the current message chain
if (makeLatestChildMessage) {
const currentMessageChain = buildLatestMessageChain(
@@ -1044,8 +1046,6 @@ export function ChatPage({
resetInputBar();
let messageUpdates: Message[] | null = null;
let answer = "";
let stopReason: StreamStopReason | null = null;
let query: string | null = null;
let retrievalType: RetrievalType =
@@ -1058,12 +1058,14 @@ export function ChatPage({
let stackTrace: string | null = null;
let finalMessage: BackendMessage | null = null;
let toolCalls: ToolCallMetadata[] = [];
let toolCall: ToolCallMetadata | null = null;
let initialFetchDetails: null | {
user_message_id: number;
assistant_message_id: number;
frozenMessageMap: Map<number, Message>;
initialDynamicParentMessage: Message;
initialDynamicAssistantMessage: Message;
} = null;
try {
@@ -1122,7 +1124,16 @@ export function ChatPage({
return new Promise((resolve) => setTimeout(resolve, ms));
};
let updateFn = (messages: Message[]) => {
return upsertToCompleteMessageMap({
messages: messages,
chatSessionId: currChatSessionId,
});
};
await delay(50);
let dynamicParentMessage: Message | null = null;
let dynamicAssistantMessage: Message | null = null;
while (!stack.isComplete || !stack.isEmpty()) {
await delay(0.5);
@@ -1156,12 +1167,12 @@ export function ChatPage({
messageUpdates = [
{
messageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
? regenerationRequest?.messageId
: user_message_id,
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
toolCall: null,
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
];
@@ -1176,22 +1187,109 @@ export function ChatPage({
});
}
const { messageMap: currentFrozenMessageMap } =
let { messageMap: currentFrozenMessageMap } =
upsertToCompleteMessageMap({
messages: messageUpdates,
chatSessionId: currChatSessionId,
});
const frozenMessageMap = currentFrozenMessageMap;
let frozenMessageMap = currentFrozenMessageMap;
let initialDynamicParentMessage: Message = regenerationRequest
? regenerationRequest?.parentMessage
: {
messageId: user_message_id!,
message: "",
type: "user",
files: currentMessageFiles,
toolCall: null,
parentMessageId: error ? null : lastSuccessfulMessageId,
childrenMessageIds: [assistant_message_id!],
latestChildMessageId: -100,
};
let initialDynamicAssistantMessage: Message = {
messageId: assistant_message_id!,
message: "",
type: "assistant",
retrievalType,
query: finalMessage?.rephrased_query || query,
documents: finalMessage?.context_docs?.top_documents || documents,
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCall: finalMessage?.tool_call || toolCall,
parentMessageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: user_message_id,
alternateAssistantID: alternativeAssistant?.id,
stackTrace: stackTrace,
overridden_model: finalMessage?.overridden_model,
stopReason: stopReason,
};
initialFetchDetails = {
frozenMessageMap,
assistant_message_id,
user_message_id,
initialDynamicParentMessage,
initialDynamicAssistantMessage,
};
resetRegenerationState();
} else {
const { user_message_id, frozenMessageMap } = initialFetchDetails;
let {
initialDynamicParentMessage,
initialDynamicAssistantMessage,
user_message_id,
frozenMessageMap,
} = initialFetchDetails;
if (
dynamicParentMessage === null &&
dynamicAssistantMessage === null
) {
dynamicParentMessage = initialDynamicParentMessage;
dynamicAssistantMessage = initialDynamicAssistantMessage;
dynamicParentMessage.message = currMessage;
}
if (!dynamicAssistantMessage || !dynamicParentMessage) {
return;
}
if (Object.hasOwn(packet, "user_message_id")) {
let newParentMessageId = dynamicParentMessage.messageId;
const messageResponseIDInfo = packet as MessageResponseIDInfo;
for (const key in dynamicAssistantMessage) {
(dynamicParentMessage as Record<string, any>)[key] = (
dynamicAssistantMessage as Record<string, any>
)[key];
}
dynamicParentMessage.parentMessageId = newParentMessageId;
dynamicParentMessage.latestChildMessageId =
messageResponseIDInfo.reserved_assistant_message_id;
dynamicParentMessage.childrenMessageIds = [
messageResponseIDInfo.reserved_assistant_message_id,
];
dynamicParentMessage.messageId =
messageResponseIDInfo.user_message_id!;
dynamicAssistantMessage = {
messageId: messageResponseIDInfo.reserved_assistant_message_id,
type: "assistant",
message: "",
documents: [],
retrievalType: undefined,
toolCall: null,
files: [],
parentMessageId: dynamicParentMessage.messageId,
childrenMessageIds: [],
latestChildMessageId: null,
};
}
setChatState((prevState) => {
if (prevState.get(chatSessionIdRef.current!) === "loading") {
@@ -1204,37 +1302,37 @@ export function ChatPage({
});
if (Object.hasOwn(packet, "answer_piece")) {
answer += (packet as AnswerPiecePacket).answer_piece;
dynamicAssistantMessage.message += (
packet as AnswerPiecePacket
).answer_piece;
} else if (Object.hasOwn(packet, "top_documents")) {
documents = (packet as DocumentsResponse).top_documents;
dynamicAssistantMessage.documents = (
packet as DocumentsResponse
).top_documents;
dynamicAssistantMessage.retrievalType = RetrievalType.Search;
retrievalType = RetrievalType.Search;
if (documents && documents.length > 0) {
// point to the latest message (we don't know the messageId yet, which is why
// we have to use -1)
setSelectedMessageForDocDisplay(user_message_id);
}
} else if (Object.hasOwn(packet, "tool_name")) {
toolCalls = [
{
tool_name: (packet as ToolCallMetadata).tool_name,
tool_args: (packet as ToolCallMetadata).tool_args,
tool_result: (packet as ToolCallMetadata).tool_result,
},
];
dynamicAssistantMessage.toolCall = {
tool_name: (packet as ToolCallMetadata).tool_name,
tool_args: (packet as ToolCallMetadata).tool_args,
tool_result: (packet as ToolCallMetadata).tool_result,
};
if (
!toolCalls[0].tool_result ||
toolCalls[0].tool_result == undefined
dynamicAssistantMessage.toolCall.tool_name === SEARCH_TOOL_NAME
) {
dynamicAssistantMessage.query =
dynamicAssistantMessage.toolCall.tool_args.query;
}
if (
!dynamicAssistantMessage.toolCall ||
!dynamicAssistantMessage.toolCall.tool_result ||
dynamicAssistantMessage.toolCall.tool_result == undefined
) {
updateChatState("toolBuilding", frozenSessionId);
} else {
updateChatState("streaming", frozenSessionId);
}
// This will be consolidated in an upcoming tool calls update,
// but for now, we need to set query as early as possible
if (toolCalls[0].tool_name == SEARCH_TOOL_NAME) {
query = toolCalls[0].tool_args["query"];
}
} else if (Object.hasOwn(packet, "file_ids")) {
aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
(fileId) => {
@@ -1244,82 +1342,54 @@ export function ChatPage({
};
}
);
dynamicAssistantMessage.files = aiMessageImages;
} else if (Object.hasOwn(packet, "error")) {
error = (packet as StreamingError).error;
stackTrace = (packet as StreamingError).stack_trace;
dynamicAssistantMessage.stackTrace = (
packet as StreamingError
).stack_trace;
} else if (Object.hasOwn(packet, "message_id")) {
finalMessage = packet as BackendMessage;
dynamicAssistantMessage = {
...dynamicAssistantMessage,
...finalMessage,
};
} else if (Object.hasOwn(packet, "stop_reason")) {
const stop_reason = (packet as StreamStopInfo).stop_reason;
if (stop_reason === StreamStopReason.CONTEXT_LENGTH) {
updateCanContinue(true, frozenSessionId);
}
}
if (!Object.hasOwn(packet, "stop_reason")) {
updateFn = (messages: Message[]) => {
const replacementsMap = regenerationRequest
? new Map([
[
regenerationRequest?.parentMessage?.messageId,
regenerationRequest?.parentMessage?.messageId,
],
[
dynamicParentMessage?.messageId,
dynamicAssistantMessage?.messageId,
],
] as [number, number][])
: null;
// on initial message send, we insert a dummy system message
// set this as the parent here if no parent is set
parentMessage =
parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
return upsertToCompleteMessageMap({
messages: messages,
replacementsMap: replacementsMap,
completeMessageMapOverride: frozenMessageMap,
chatSessionId: frozenSessionId!,
});
};
const updateFn = (messages: Message[]) => {
const replacementsMap = regenerationRequest
? new Map([
[
regenerationRequest?.parentMessage?.messageId,
regenerationRequest?.parentMessage?.messageId,
],
[
regenerationRequest?.messageId,
initialFetchDetails?.assistant_message_id,
],
] as [number, number][])
: null;
return upsertToCompleteMessageMap({
messages: messages,
replacementsMap: replacementsMap,
completeMessageMapOverride: frozenMessageMap,
chatSessionId: frozenSessionId!,
});
};
updateFn([
{
messageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: initialFetchDetails.user_message_id!,
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
parentMessageId: error ? null : lastSuccessfulMessageId,
childrenMessageIds: [
...(regenerationRequest?.parentMessage?.childrenMessageIds ||
[]),
initialFetchDetails.assistant_message_id!,
],
latestChildMessageId: initialFetchDetails.assistant_message_id,
},
{
messageId: initialFetchDetails.assistant_message_id!,
message: error || answer,
type: error ? "error" : "assistant",
retrievalType,
query: finalMessage?.rephrased_query || query,
documents:
finalMessage?.context_docs?.top_documents || documents,
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
parentMessageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: initialFetchDetails.user_message_id,
alternateAssistantID: alternativeAssistant?.id,
stackTrace: stackTrace,
overridden_model: finalMessage?.overridden_model,
stopReason: stopReason,
},
]);
let { messageMap } = updateFn([
dynamicParentMessage,
dynamicAssistantMessage,
]);
frozenMessageMap = messageMap;
}
}
}
}
@@ -1333,7 +1403,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
toolCall: null,
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
{
@@ -1343,7 +1413,7 @@ export function ChatPage({
message: errorMsg,
type: "error",
files: aiMessageImages || [],
toolCalls: [],
toolCall: null,
parentMessageId:
initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
},
@@ -1962,9 +2032,8 @@ export function ChatPage({
completeMessageDetail
);
const messageReactComponentKey = `${i}-${currentSessionId()}`;
const parentMessage = message.parentMessageId
? messageMap.get(message.parentMessageId)
: null;
const parentMessage =
i > 1 ? messageHistory[i - 1] : null;
if (message.type === "user") {
if (
(currentSessionChatState == "loading" &&
@@ -2055,6 +2124,25 @@ export function ChatPage({
) {
return <></>;
}
const mostRecentNonAIParent = messageHistory
.slice(0, i)
.reverse()
.find((msg) => msg.type !== "assistant");
const hasChildMessage =
message.latestChildMessageId !== null &&
message.latestChildMessageId !== undefined;
const childMessage = hasChildMessage
? messageMap.get(
message.latestChildMessageId!
)
: null;
const hasParentAI =
parentMessage?.type == "assistant";
const hasChildAI =
childMessage?.type == "assistant";
return (
<div
id={`message-${message.messageId}`}
@@ -2066,6 +2154,9 @@ export function ChatPage({
}
>
<AIMessage
setPopup={setPopup}
hasChildAI={hasChildAI}
hasParentAI={hasParentAI}
continueGenerating={
i == messageHistory.length - 1 &&
currentCanContinue()
@@ -2075,7 +2166,7 @@ export function ChatPage({
overriddenModel={message.overridden_model}
regenerate={createRegenerator({
messageId: message.messageId,
parentMessage: parentMessage!,
parentMessage: mostRecentNonAIParent!,
})}
otherMessagesCanSwitchTo={
parentMessage?.childrenMessageIds || []
@@ -2112,18 +2203,15 @@ export function ChatPage({
}
messageId={message.messageId}
content={message.message}
// content={message.message}
files={message.files}
query={
messageHistory[i]?.query || undefined
}
personaName={liveAssistant.name}
citedDocuments={getCitedDocumentsFromMessage(
message
)}
toolCall={
message.toolCalls &&
message.toolCalls[0]
message.toolCall && message.toolCall
}
isComplete={
i !== messageHistory.length - 1 ||
@@ -2147,7 +2235,6 @@ export function ChatPage({
])
}
handleSearchQueryEdit={
i === messageHistory.length - 1 &&
currentSessionChatState == "input"
? (newQuery) => {
if (!previousMessage) {
@@ -2231,7 +2318,6 @@ export function ChatPage({
<AIMessage
currentPersona={liveAssistant}
messageId={message.messageId}
personaName={liveAssistant.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
@@ -2279,7 +2365,6 @@ export function ChatPage({
alternativeAssistant
}
messageId={null}
personaName={liveAssistant.name}
content={
<div
key={"Generating"}
@@ -2299,7 +2384,6 @@ export function ChatPage({
<AIMessage
currentPersona={liveAssistant}
messageId={-1}
personaName={liveAssistant.name}
content={
<p className="text-red-700 text-sm my-auto">
{loadingError}

View File

@@ -85,7 +85,7 @@ export interface Message {
documents?: DanswerDocument[] | null;
citations?: CitationMap;
files: FileDescriptor[];
toolCalls: ToolCallMetadata[];
toolCall: ToolCallMetadata | null;
// for rebuilding the message tree
parentMessageId: number | null;
childrenMessageIds?: number[];
@@ -120,7 +120,7 @@ export interface BackendMessage {
time_sent: string;
citations: CitationMap;
files: FileDescriptor[];
tool_calls: ToolCallFinalResult[];
tool_call: ToolCallFinalResult | null;
alternate_assistant_id?: number | null;
overridden_model?: string;
}
@@ -143,3 +143,10 @@ export interface StreamingError {
error: string;
stack_trace: string;
}
export interface ImageGenerationResult {
revised_prompt: string;
url: string;
}
export type ImageGenerationResults = ImageGenerationResult[];

View File

@@ -435,7 +435,7 @@ export function processRawChatHistory(
citations: messageInfo?.citations || {},
}
: {}),
toolCalls: messageInfo.tool_calls,
toolCall: messageInfo.tool_call,
parentMessageId: messageInfo.parent_message,
childrenMessageIds: [],
latestChildMessageId: messageInfo.latest_child_message,
@@ -479,6 +479,7 @@ export function buildLatestMessageChain(
let currMessage: Message | null = rootMessage;
while (currMessage) {
finalMessageList.push(currMessage);
const childMessageNumber = currMessage.latestChildMessageId;
if (childMessageNumber && messageMap.has(childMessageNumber)) {
currMessage = messageMap.get(childMessageNumber) as Message;

View File

@@ -8,25 +8,22 @@ import {
FiGlobe,
} from "react-icons/fi";
import { FeedbackType } from "../types";
import {
Dispatch,
SetStateAction,
useContext,
useEffect,
useRef,
useState,
} from "react";
import { useContext, useEffect, useRef, useState } from "react";
import ReactMarkdown from "react-markdown";
import {
DanswerDocument,
FilteredDanswerDocument,
} from "@/lib/search/interfaces";
import { SearchSummary } from "./SearchSummary";
import { SourceIcon } from "@/components/SourceIcon";
import { SkippedSearch } from "./SkippedSearch";
import remarkGfm from "remark-gfm";
import { CopyButton } from "@/components/CopyButton";
import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces";
import {
ChatFileType,
FileDescriptor,
ImageGenerationResults,
ToolCallMetadata,
} from "../interfaces";
import {
IMAGE_GENERATION_TOOL_NAME,
SEARCH_TOOL_NAME,
@@ -44,13 +41,11 @@ import "./custom-code-styles.css";
import { Persona } from "@/app/admin/assistants/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Citation } from "@/components/search/results/Citation";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
import {
ThumbsUpIcon,
ThumbsDownIcon,
LikeFeedback,
DislikeFeedback,
ToolCallIcon,
} from "@/components/icons/icons";
import {
CustomTooltip,
@@ -59,12 +54,14 @@ import {
import { ValidSources } from "@/lib/types";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useMouseTracking } from "./hooks";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import { ContinueGenerating } from "./ContinueMessage";
import DualPromptDisplay from "../tools/ImagePromptCitaiton";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { Popover } from "@/components/popover/Popover";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -121,6 +118,8 @@ function FileDisplay({
}
export const AIMessage = ({
hasChildAI,
hasParentAI,
regenerate,
overriddenModel,
continueGenerating,
@@ -134,7 +133,6 @@ export const AIMessage = ({
files,
selectedDocuments,
query,
personaName,
citedDocuments,
toolCall,
isComplete,
@@ -148,7 +146,10 @@ export const AIMessage = ({
currentPersona,
otherMessagesCanSwitchTo,
onMessageSelection,
setPopup,
}: {
hasChildAI?: boolean;
hasParentAI?: boolean;
shared?: boolean;
isActive?: boolean;
continueGenerating?: () => void;
@@ -163,9 +164,8 @@ export const AIMessage = ({
content: string | JSX.Element;
files?: FileDescriptor[];
query?: string;
personaName?: string;
citedDocuments?: [string, DanswerDocument][] | null;
toolCall?: ToolCallMetadata;
toolCall?: ToolCallMetadata | null;
isComplete?: boolean;
hasDocs?: boolean;
handleFeedback?: (feedbackType: FeedbackType) => void;
@@ -176,7 +176,10 @@ export const AIMessage = ({
retrievalDisabled?: boolean;
overriddenModel?: string;
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
setPopup?: (popupSpec: PopupSpec | null) => void;
}) => {
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
const toolCallGenerating = toolCall && !toolCall.tool_result;
const processContent = (content: string | JSX.Element) => {
if (typeof content !== "string") {
@@ -199,12 +202,20 @@ export const AIMessage = ({
return content;
}
}
if (
isComplete &&
toolCall?.tool_result &&
toolCall.tool_name == IMAGE_GENERATION_TOOL_NAME
) {
return content + ` [${toolCall.tool_name}]()`;
}
return content + (!isComplete && !toolCallGenerating ? " [*]() " : "");
};
const finalContent = processContent(content as string);
const [isRegenerateHovered, setIsRegenerateHovered] = useState(false);
const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking();
const settings = useContext(SettingsContext);
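These appended markers are the contract between processContent and the custom markdown renderers below: " [*]() " becomes the pulsing streaming dot, and a trailing "[<tool name>]()" on a finished image generation becomes the prompt popover. Restated as a standalone sketch (not the component code itself):

// Same logic as processContent above, isolated for reference.
function markContent(
  content: string,
  isComplete: boolean,
  toolCall: ToolCallMetadata | null
): string {
  const toolCallGenerating = !!toolCall && !toolCall.tool_result;
  if (
    isComplete &&
    toolCall?.tool_result &&
    toolCall.tool_name === IMAGE_GENERATION_TOOL_NAME
  ) {
    return content + ` [${toolCall.tool_name}]()`;
  }
  return content + (!isComplete && !toolCallGenerating ? " [*]() " : "");
}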
@@ -274,39 +285,50 @@ export const AIMessage = ({
<div
id="danswer-ai-message"
ref={trackedElementRef}
className={"py-5 ml-4 px-5 relative flex "}
className={`${hasParentAI ? "pb-5" : "py-5"} px-2 lg:px-5 relative flex `}
>
<div
className={`mx-auto ${shared ? "w-full" : "w-[90%]"} max-w-message-max`}
>
<div className={`desktop:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
<div className="flex">
<AssistantIcon
size="small"
assistant={alternativeAssistant || currentPersona}
/>
{!hasParentAI ? (
<AssistantIcon
size="small"
assistant={alternativeAssistant || currentPersona}
/>
) : (
<div className="w-6" />
)}
<div className="w-full">
<div className="max-w-message-max break-words">
<div className="w-full ml-4">
<div className="max-w-message-max break-words">
{!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME ? (
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (
<>
{query !== undefined &&
handleShowRetrieved !== undefined &&
isCurrentlyShowingRetrieved !== undefined &&
!retrievalDisabled && (
<div className="mb-1">
<SearchSummary
docs={docs}
filteredDocs={filteredDocs}
query={query}
finished={toolCall?.tool_result != undefined}
hasDocs={hasDocs || false}
messageId={messageId}
handleShowRetrieved={handleShowRetrieved}
finished={
toolCall?.tool_result != undefined ||
isComplete!
}
toggleDocumentSelection={
toggleDocumentSelection
}
handleSearchQueryEdit={handleSearchQueryEdit}
/>
</div>
)}
{handleForceSearch &&
!hasChildAI &&
content &&
query === undefined &&
!hasDocs &&
@@ -318,7 +340,7 @@ export const AIMessage = ({
</div>
)}
</>
) : null}
)}
{toolCall &&
!TOOLS_WITH_CUSTOM_HANDLING.includes(
@@ -371,6 +393,50 @@ export const AIMessage = ({
const { node, ...rest } = props;
const value = rest.children;
if (
value
?.toString()
.startsWith(IMAGE_GENERATION_TOOL_NAME)
) {
const imageGenerationResult =
toolCall?.tool_result as ImageGenerationResults;
return (
<Popover
open={isPopoverOpen}
onOpenChange={
  // noop: the popover is opened and closed only from the icon button
  () => null
}
content={
<button
onMouseDown={() => {
setIsPopoverOpen(!isPopoverOpen);
}}
>
<ToolCallIcon className="cursor-pointer flex-none text-blue-500 hover:text-blue-700 !h-4 !w-4 inline-block" />
</button>
}
popover={
<DualPromptDisplay
arg="Prompt"
setPopup={setPopup!}
prompt1={imageGenerationResult[0]?.revised_prompt ?? ""}
prompt2={imageGenerationResult[1]?.revised_prompt}
/>
}
side="top"
align="center"
/>
);
}
if (value?.toString().startsWith("*")) {
return (
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
@@ -378,10 +444,6 @@ export const AIMessage = ({
} else if (
value?.toString().startsWith("[")
) {
// for some reason <a> tags cause the onClick to not apply
// and the links are unclickable
// TODO: fix the fact that you have to double click to follow link
// for the first link
return (
<Citation link={rest?.href}>
{rest.children}
@@ -428,82 +490,10 @@ export const AIMessage = ({
) : isComplete ? null : (
<></>
)}
{isComplete && docs && docs.length > 0 && (
<div className="mt-2 -mx-8 w-full mb-4 flex relative">
<div className="w-full">
<div className="px-8 flex gap-x-2">
{!settings?.isMobile &&
filteredDocs.length > 0 &&
filteredDocs.slice(0, 2).map((doc, ind) => (
<div
key={doc.document_id}
className={`w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 pb-2 pt-1 border-b
`}
>
<a
href={doc.link || undefined}
target="_blank"
className="text-sm flex w-full pt-1 gap-x-1.5 overflow-hidden justify-between font-semibold text-text-700"
>
<Citation link={doc.link} index={ind + 1} />
<p className="shrink truncate ellipsis break-all">
{doc.semantic_identifier ||
doc.document_id}
</p>
<div className="ml-auto flex-none">
{doc.is_internet ? (
<InternetSearchIcon url={doc.link} />
) : (
<SourceIcon
sourceType={doc.source_type}
iconSize={18}
/>
)}
</div>
</a>
<div className="flex overscroll-x-scroll mt-.5">
<DocumentMetadataBlock document={doc} />
</div>
<div className="line-clamp-3 text-xs break-words pt-1">
{doc.blurb}
</div>
</div>
))}
<div
onClick={() => {
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={-1}
className="cursor-pointer w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 py-2 border-b"
>
<div className="text-sm flex justify-between font-semibold text-text-700">
<p className="line-clamp-1">See context</p>
<div className="flex gap-x-1">
{uniqueSources.map((sourceType, ind) => {
return (
<div key={ind} className="flex-none">
<SourceIcon
sourceType={sourceType}
iconSize={18}
/>
</div>
);
})}
</div>
</div>
<div className="line-clamp-3 text-xs break-words pt-1">
See more
</div>
</div>
</div>
</div>
</div>
)}
</div>
{handleFeedback &&
{!hasChildAI &&
handleFeedback &&
(isActive ? (
<div
className={`

View File

@@ -4,9 +4,21 @@ import {
} from "@/components/BasicClickable";
import { HoverPopup } from "@/components/HoverPopup";
import { Hoverable } from "@/components/Hoverable";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { SourceIcon } from "@/components/SourceIcon";
import { ChevronDownIcon, InfoIcon } from "@/components/icons/icons";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
import { Citation } from "@/components/search/results/Citation";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useEffect, useRef, useState } from "react";
import {
DanswerDocument,
FilteredDanswerDocument,
} from "@/lib/search/interfaces";
import { ValidSources } from "@/lib/types";
import { useContext, useEffect, useRef, useState } from "react";
import { FiCheck, FiEdit2, FiSearch, FiX } from "react-icons/fi";
export function ShowHideDocsButton({
messageId,
@@ -37,24 +49,31 @@ export function ShowHideDocsButton({
export function SearchSummary({
query,
hasDocs,
filteredDocs,
finished,
messageId,
handleShowRetrieved,
docs,
toggleDocumentSelection,
handleSearchQueryEdit,
}: {
toggleDocumentSelection?: () => void;
docs?: DanswerDocument[] | null;
filteredDocs: FilteredDanswerDocument[];
finished: boolean;
query: string;
hasDocs: boolean;
messageId: number | null;
handleShowRetrieved: (messageId: number | null) => void;
handleSearchQueryEdit?: (query: string) => void;
}) {
const [isEditing, setIsEditing] = useState(false);
const [finalQuery, setFinalQuery] = useState(query);
const [isOverflowed, setIsOverflowed] = useState(false);
const searchingForRef = useRef<HTMLDivElement>(null);
const editQueryRef = useRef<HTMLInputElement>(null);
const [isDropdownOpen, setIsDropdownOpen] = useState(false);
const searchingForRef = useRef<HTMLDivElement | null>(null);
const editQueryRef = useRef<HTMLInputElement | null>(null);
const settings = useContext(SettingsContext);
const toggleDropdown = () => {
setIsDropdownOpen(!isDropdownOpen);
};
useEffect(() => {
const checkOverflow = () => {
@@ -68,7 +87,7 @@ export function SearchSummary({
};
checkOverflow();
window.addEventListener("resize", checkOverflow);
return () => window.removeEventListener("resize", checkOverflow);
}, []);
@@ -86,15 +105,30 @@ export function SearchSummary({
}, [query]);
const searchingForDisplay = (
<div className={`flex p-1 rounded ${isOverflowed && "cursor-default"}`}>
<FiSearch className="flex-none mr-2 my-auto" size={14} />
<div
className={`${!finished && "loading-text"}
!text-sm !line-clamp-1 !break-all px-0.5`}
ref={searchingForRef}
>
{finished ? "Searched" : "Searching"} for: <i> {finalQuery}</i>
</div>
<div
className={`flex my-auto items-center ${isOverflowed && "cursor-default"}`}
>
{finished ? (
<>
<div
onClick={() => {
toggleDropdown();
}}
className={`transition-colors duration-300 group-hover:text-text-toolhover cursor-pointer text-text-toolrun !line-clamp-1 !break-all pr-0.5`}
ref={searchingForRef}
>
Searched {filteredDocs.length > 0 && filteredDocs.length} document
{filteredDocs.length != 1 && "s"} for {query}
</div>
</>
) : (
<div
className={`loading-text !text-sm !line-clamp-1 !break-all px-0.5`}
ref={searchingForRef}
>
Searching for: <i> {finalQuery}</i>
</div>
)}
</div>
);
@@ -145,43 +179,126 @@ export function SearchSummary({
</div>
</div>
) : null;
const SearchBlock = ({ doc, ind }: { doc: DanswerDocument; ind: number }) => {
return (
<div
onClick={() => {
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={doc.document_id}
className={`flex items-start gap-3 px-4 py-3 text-token-text-secondary ${ind == 0 && "rounded-t-xl"} hover:bg-background-100 group relative text-sm`}
>
<div className="mt-1 scale-[.9] flex-none">
{doc.is_internet ? (
<InternetSearchIcon url={doc.link} />
) : (
<SourceIcon sourceType={doc.source_type} iconSize={18} />
)}
</div>
<div className="flex flex-col">
<a
href={doc.link}
target="_blank"
className="line-clamp-1 text-text-900"
>
<p className="shrink truncate ellipsis break-all ">
{doc.semantic_identifier || doc.document_id}
</p>
<p className="line-clamp-3 text-text-500 break-words">
{doc.blurb}
</p>
</a>
</div>
</div>
);
};
return (
<div className="flex">
{isEditing ? (
editInput
) : (
<>
<div className="text-sm">
{isOverflowed ? (
<HoverPopup
mainContent={searchingForDisplay}
popupContent={
<div>
<b>Full query:</b>{" "}
<div className="mt-1 italic">{query}</div>
</div>
}
direction="top"
<>
<div className="flex gap-x-2 group">
{isEditing ? (
editInput
) : (
<>
<div className="my-auto text-sm">
{isOverflowed ? (
<HoverPopup
mainContent={searchingForDisplay}
popupContent={
<div>
<b>Full query:</b>{" "}
<div className="mt-1 italic">{query}</div>
</div>
}
direction="top"
/>
) : (
searchingForDisplay
)}
</div>
<button
className="my-auto invisible group-hover:visible transition-all duration-300 rounded"
onClick={toggleDropdown}
>
<ChevronDownIcon
className={`transform transition-transform ${isDropdownOpen ? "rotate-180" : ""}`}
/>
</button>
{handleSearchQueryEdit ? (
<Tooltip delayDuration={1000} content={"Edit Search"}>
<button
className="my-auto invisible group-hover:visible transition-all duration-300 cursor-pointer rounded"
onClick={() => {
setIsEditing(true);
}}
>
<FiEdit2 />
</button>
</Tooltip>
) : (
<></>
)}
</>
)}
</div>
{isDropdownOpen && docs && docs.length > 0 && (
<div
className={`mt-2 -mx-8 w-full mb-4 flex relative transition-all duration-300 ${isDropdownOpen ? "opacity-100 max-h-[1000px]" : "opacity-0 max-h-0"}`}
>
<div className="w-full">
<div className="mx-8 flex rounded max-h-[500px] overflow-y-scroll rounded-lg border-1.5 border divide-y divider-y-1.5 divider-y-border border-border flex-col gap-x-4">
{!settings?.isMobile &&
filteredDocs.length > 0 &&
filteredDocs.map((doc, ind) => (
<SearchBlock key={ind} doc={doc} ind={ind} />
))}
<div
onClick={() => {
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={-1}
className="cursor-pointer w-full flex transition-all duration-500 hover:bg-background-100 py-3 border-b"
>
<div key={0} className="px-3 invisible scale-[.9] flex-none">
<SourceIcon sourceType={"file"} iconSize={18} />
</div>
<div className="text-sm flex justify-between text-text-900">
<p className="line-clamp-1">See context</p>
<div className="flex gap-x-1"></div>
</div>
</div>
</div>
</div>
</div>
)}
</>
);
}
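The result cards that AIMessage used to render inline now live behind this dropdown, which is why docs, filteredDocs, and toggleDocumentSelection joined the props. A minimal call site under the updated signature; the wrapper component and handler bodies are illustrative, not from this diff:

// Illustrative only; handlers are placeholders.
const ExampleSummary = ({
  docs,
  filteredDocs,
}: {
  docs: DanswerDocument[];
  filteredDocs: FilteredDanswerDocument[];
}) => (
  <SearchSummary
    query="onboarding checklist"
    finished
    hasDocs={docs.length > 0}
    messageId={42}
    docs={docs}
    filteredDocs={filteredDocs}
    handleShowRetrieved={(id) => console.log("show retrieved for", id)}
    toggleDocumentSelection={() => console.log("toggle document sidebar")}
    handleSearchQueryEdit={(newQuery) => console.log("rerun with", newQuery)}
  />
);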

View File

@@ -101,7 +101,6 @@ export function SharedChatDisplay({
messageId={message.messageId}
content={message.message}
files={message.files || []}
personaName={chatSession.persona_name}
citedDocuments={getCitedDocumentsFromMessage(message)}
isComplete
/>

View File

@@ -0,0 +1,83 @@
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { CopyIcon } from "@/components/icons/icons";
import { Divider } from "@tremor/react";
import React, { forwardRef, useState } from "react";
import { FiCheck } from "react-icons/fi";
interface PromptDisplayProps {
prompt1: string;
prompt2?: string;
arg: string;
setPopup: (popupSpec: PopupSpec | null) => void;
}
const DualPromptDisplay = forwardRef<HTMLDivElement, PromptDisplayProps>(
({ prompt1, prompt2, setPopup, arg }, ref) => {
const [copied, setCopied] = useState<number | null>(null);
const copyToClipboard = (text: string, index: number) => {
navigator.clipboard
.writeText(text)
.then(() => {
setPopup({ message: "Copied to clipboard", type: "success" });
setCopied(index);
setTimeout(() => setCopied(null), 2000); // Reset copy status after 2 seconds
})
.catch(() => {
setPopup({ message: "Failed to copy", type: "error" });
});
};
const PromptSection = ({
copied,
prompt,
index,
}: {
copied: number | null;
prompt: string;
index: number;
}) => (
<div className="w-full p-2 rounded-lg">
<h2 className="text-lg font-semibold mb-2">
{arg} {index + 1}
</h2>
<p className="line-clamp-6 text-sm text-gray-800">{prompt}</p>
<button
onMouseDown={() => copyToClipboard(prompt, index)}
className="flex mt-2 text-sm cursor-pointer items-center justify-center py-2 px-3 border border-background-200 bg-inverted text-text-900 rounded-full hover:bg-background-100 transition duration-200"
>
{copied == index ? (
<>
<FiCheck className="mr-2" size={16} />
Copied!
</>
) : (
<>
<CopyIcon className="mr-2" size={16} />
Copy
</>
)}
</button>
</div>
);
return (
<div className="w-[400px] bg-inverted mx-auto p-6 rounded-lg shadow-lg">
<div className="flex flex-col gap-x-4">
<PromptSection copied={copied} prompt={prompt1} index={0} />
{prompt2 && (
<>
<Divider />
<PromptSection copied={copied} prompt={prompt2} index={1} />
</>
)}
</div>
</div>
);
}
);
DualPromptDisplay.displayName = "DualPromptDisplay";
export default DualPromptDisplay;
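This is the panel behind the image-citation popover in AIMessage: up to two revised prompts with copy-to-clipboard buttons, reporting success or failure through setPopup, with the Copied! state reset after two seconds. Illustrative usage (prompt strings are placeholders):

// Placeholder values; mirrors the call site in AIMessage.
const promptPopover = (
  <DualPromptDisplay
    arg="Prompt"
    prompt1="A watercolor fox in a snowy forest"
    prompt2="A watercolor fox, snowy forest, soft morning light"
    setPopup={(spec) => spec && console.log(`${spec.type}: ${spec.message}`)}
  />
);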

View File

@@ -145,6 +145,19 @@ export function ClientLayout({
),
link: "/admin/tools",
},
...(enableEnterprise
? [
{
name: (
<div className="flex">
<ClipboardIcon size={18} />
<div className="ml-1">Standard Answers</div>
</div>
),
link: "/admin/standard-answer",
},
]
: []),
{
name: (
<div className="flex">
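The Standard Answers entry is added with the conditional-spread idiom, which keeps the nav array literal flat instead of concatenating lists after the fact. The same idiom in isolation:

// "b" is included only when the flag is true.
const flag = true;
const items = ["a", ...(flag ? ["b"] : []), "c"]; // -> ["a", "b", "c"]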

View File

@@ -2811,3 +2811,40 @@ export const WindowsIcon = ({
</svg>
);
};
export const ToolCallIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 19 15"
fill="none"
>
<path
d="M4.42 0.75H2.8625H2.75C1.64543 0.75 0.75 1.64543 0.75 2.75V11.65C0.75 12.7546 1.64543 13.65 2.75 13.65H2.8625C2.8625 13.65 2.8625 13.65 2.8625 13.65C2.8625 13.65 4.00751 13.65 4.42 13.65M13.98 13.65H15.5375H15.65C16.7546 13.65 17.65 12.7546 17.65 11.65V2.75C17.65 1.64543 16.7546 0.75 15.65 0.75H15.5375H13.98"
stroke="currentColor"
strokeWidth="1.5"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M5.55283 4.21963C5.25993 3.92674 4.78506 3.92674 4.49217 4.21963C4.19927 4.51252 4.19927 4.9874 4.49217 5.28029L6.36184 7.14996L4.49217 9.01963C4.19927 9.31252 4.19927 9.7874 4.49217 10.0803C4.78506 10.3732 5.25993 10.3732 5.55283 10.0803L7.95283 7.68029C8.24572 7.3874 8.24572 6.91252 7.95283 6.61963L5.55283 4.21963Z"
fill="currentColor"
stroke="currentColor"
strokeWidth="0.2"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M9.77753 8.75003C9.3357 8.75003 8.97753 9.10821 8.97753 9.55003C8.97753 9.99186 9.3357 10.35 9.77753 10.35H13.2775C13.7194 10.35 14.0775 9.99186 14.0775 9.55003C14.0775 9.10821 13.7194 8.75003 13.2775 8.75003H9.77753Z"
fill="currentColor"
stroke="currentColor"
strokeWidth="0.1"
/>
</svg>
);
};

View File

@@ -3,7 +3,6 @@
import {
DanswerDocument,
DocumentRelevance,
Relevance,
SearchDanswerDocument,
} from "@/lib/search/interfaces";
import { DocumentFeedbackBlock } from "./DocumentFeedbackBlock";
@@ -12,11 +11,10 @@ import { PopupSpec } from "../admin/connectors/Popup";
import { DocumentUpdatedAtBadge } from "./DocumentUpdatedAtBadge";
import { SourceIcon } from "../SourceIcon";
import { MetadataBadge } from "../MetadataBadge";
import { BookIcon, CheckmarkIcon, LightBulbIcon, XIcon } from "../icons/icons";
import { BookIcon, LightBulbIcon } from "../icons/icons";
import { FaStar } from "react-icons/fa";
import { FiTag } from "react-icons/fi";
import { DISABLE_LLM_DOC_RELEVANCE } from "@/lib/constants";
import { SettingsContext } from "../settings/SettingsProvider";
import { CustomTooltip, TooltipGroup } from "../tooltip/CustomTooltip";
import { WarningCircle } from "@phosphor-icons/react";

View File

@@ -22,6 +22,8 @@ export interface AnswerPiecePacket {
export enum StreamStopReason {
CONTEXT_LENGTH = "CONTEXT_LENGTH",
CANCELLED = "CANCELLED",
FINISHED = "FINISHED",
NEW_RESPONSE = "NEW_RESPONSE",
}
export interface StreamStopInfo {
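FINISHED and NEW_RESPONSE mirror the backend enum additions and support sequential tool calls: a stream can close one response and signal that another is about to begin. A sketch of a consumer branch, with declared placeholder handlers since the real dispatch is not shown in this excerpt:

// Handler names are placeholders; NEW_RESPONSE semantics are inferred from
// the sequential-tool-call commits in this compare.
declare function finalizeCurrentMessage(): void;
declare function beginNextMessage(): void;
declare function markMessageIncomplete(): void;

function onStreamStop(reason: StreamStopReason) {
  switch (reason) {
    case StreamStopReason.FINISHED:
      finalizeCurrentMessage(); // response completed normally
      break;
    case StreamStopReason.NEW_RESPONSE:
      beginNextMessage(); // a follow-up response will stream next
      break;
    case StreamStopReason.CONTEXT_LENGTH:
    case StreamStopReason.CANCELLED:
      markMessageIncomplete(); // stopped early
      break;
  }
}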