mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-26 12:15:48 +00:00
Compare commits
8 Commits
v3.0.0-clo
...
new_seq_to
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
59e9a33b30 | ||
|
|
6e60437c56 | ||
|
|
9cde51f1a2 | ||
|
|
8b8952f117 | ||
|
|
dc01eea610 | ||
|
|
c89d8318c0 | ||
|
|
3f2d6557dc | ||
|
|
b3818877af |
@@ -48,6 +48,7 @@ class QADocsResponse(RetrievalDocs):
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
NEW_RESPONSE = "new_response"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
|
||||
@@ -19,6 +19,7 @@ from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
|
||||
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION
|
||||
@@ -137,6 +138,7 @@ def _translate_citations(
|
||||
"""Always cites the first instance of the document_id, assumes the db_docs
|
||||
are sorted in the order displayed in the UI"""
|
||||
doc_id_to_saved_doc_id_map: dict[str, int] = {}
|
||||
|
||||
for db_doc in db_docs:
|
||||
if db_doc.document_id not in doc_id_to_saved_doc_id_map:
|
||||
doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id
|
||||
@@ -687,6 +689,10 @@ def stream_chat_message_objects(
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
tool_name_to_tool_id = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
@@ -729,6 +735,74 @@ def stream_chat_message_objects(
|
||||
tool_result = None
|
||||
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
|
||||
break
|
||||
db_citations = None
|
||||
|
||||
if reference_db_search_docs:
|
||||
db_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
if tool_result is None:
|
||||
tool_call = None
|
||||
else:
|
||||
tool_call = ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=cast(
|
||||
QADocsResponse, qa_docs_response
|
||||
).rephrased_query
|
||||
if qa_docs_response is not None
|
||||
else None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=cast(MessageSpecificCitations, db_citations).citation_map
|
||||
if db_citations is not None
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message.id
|
||||
if user_message is not None
|
||||
else gen_ai_response_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
reference_db_search_docs = None
|
||||
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
@@ -869,6 +943,8 @@ def stream_chat_message_objects(
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
if answer.llm_answer == "":
|
||||
return
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
|
||||
@@ -9,6 +9,8 @@ from langchain_core.messages import ToolCall
|
||||
from danswer.chat.models import AnswerQuestionPossibleReturn
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.llm_response_handler import LLMCall
|
||||
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
|
||||
@@ -118,6 +120,9 @@ class Answer:
|
||||
)
|
||||
and not skip_explicit_tool_calling
|
||||
)
|
||||
self.current_streamed_output: list = []
|
||||
|
||||
self.processing_stream: list = []
|
||||
|
||||
def _get_tools_list(self) -> list[Tool]:
|
||||
if not self.force_use_tool.force_use:
|
||||
@@ -155,6 +160,7 @@ class Answer:
|
||||
response_handler_manager = LLMResponseHandlerManager(
|
||||
tool_handler, DummyAnswerResponseHandler(), self.is_cancelled
|
||||
)
|
||||
|
||||
yield from response_handler_manager.handle_llm_response(
|
||||
iter([dummy_tool_call_chunk])
|
||||
)
|
||||
@@ -165,7 +171,13 @@ class Answer:
|
||||
else:
|
||||
raise RuntimeError("Tool call handler did not return a new LLM call")
|
||||
|
||||
def _get_response(self, llm_calls: list[LLMCall]) -> AnswerStream:
|
||||
def _get_response(
|
||||
self,
|
||||
llm_calls: list[LLMCall],
|
||||
check_for_tool_call: bool = False,
|
||||
previously_used_tool: Tool | None = None,
|
||||
previous_tool_response: ToolResponse | None = None,
|
||||
) -> AnswerStream:
|
||||
current_llm_call = llm_calls[-1]
|
||||
|
||||
# handle the case where no decision has to be made; we simply run the tool
|
||||
@@ -231,7 +243,6 @@ class Answer:
|
||||
tool_call_handler, answer_handler, self.is_cancelled
|
||||
)
|
||||
|
||||
# DEBUG: good breakpoint
|
||||
stream = self.llm.stream(
|
||||
prompt=current_llm_call.prompt_builder.build(),
|
||||
tools=[tool.tool_definition() for tool in current_llm_call.tools] or None,
|
||||
@@ -242,11 +253,101 @@ class Answer:
|
||||
),
|
||||
structured_response_format=self.answer_style_config.structured_response_format,
|
||||
)
|
||||
yield from response_handler_manager.handle_llm_response(stream)
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(current_llm_call)
|
||||
tool_call_made = False
|
||||
tool_call_name: str | None = None
|
||||
buffered_packets = []
|
||||
|
||||
tool_response = None
|
||||
for packet in response_handler_manager.handle_llm_response(stream):
|
||||
if isinstance(packet, DanswerAnswerPiece):
|
||||
pass
|
||||
|
||||
if isinstance(packet, ToolResponse):
|
||||
tool_response = packet
|
||||
|
||||
if check_for_tool_call:
|
||||
buffered_packets.append(packet)
|
||||
if isinstance(packet, ToolCallKickoff):
|
||||
# if has_streamed_text and not has_completed:
|
||||
# yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
# has_completed = True
|
||||
|
||||
tool_call_name = packet.tool_name
|
||||
tool_call_made = True
|
||||
for buffered_packet in buffered_packets:
|
||||
yield buffered_packet
|
||||
buffered_packets = []
|
||||
else:
|
||||
yield packet
|
||||
if isinstance(packet, ToolCallKickoff):
|
||||
# if has_streamed_text and not has_completed:
|
||||
# yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
# has_completed = True
|
||||
tool_call_name = packet.tool_name
|
||||
tool_call_made = True
|
||||
|
||||
if check_for_tool_call and not tool_call_made:
|
||||
for remaining_packet in buffered_packets:
|
||||
yield remaining_packet
|
||||
return
|
||||
|
||||
for remaining_packet in buffered_packets:
|
||||
yield remaining_packet
|
||||
|
||||
new_llm_call = response_handler_manager.next_llm_call(
|
||||
current_llm_call, tool_call_made
|
||||
)
|
||||
tool_used: Tool | None = None
|
||||
if tool_call_made:
|
||||
tool_used = next(
|
||||
(tool for tool in self.tools if tool.name == tool_call_name), None
|
||||
)
|
||||
|
||||
if new_llm_call:
|
||||
yield from self._get_response(llm_calls + [new_llm_call])
|
||||
yield from self._get_response(
|
||||
llm_calls + [new_llm_call],
|
||||
check_for_tool_call=not tool_call_made,
|
||||
previously_used_tool=tool_used,
|
||||
previous_tool_response=tool_response,
|
||||
)
|
||||
|
||||
else:
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
|
||||
# Logic here
|
||||
if (
|
||||
not check_for_tool_call
|
||||
and not tool_call_made
|
||||
and not previously_used_tool
|
||||
):
|
||||
return
|
||||
|
||||
if previously_used_tool:
|
||||
previously_used_tool.build_prompt_after_tool_call(
|
||||
current_llm_call.prompt_builder,
|
||||
self.question,
|
||||
self.llm_answer,
|
||||
previous_tool_response,
|
||||
)
|
||||
# Build next prompter with the original question and the LLM's last answer
|
||||
# current_llm_call.prompt_builder.update_user_prompt(HumanMessage(content=self.question))
|
||||
# current_llm_call.prompt_builder.build_next_prompter(self.question, self.llm_answer)
|
||||
|
||||
llm_call = LLMCall(
|
||||
prompt_builder=current_llm_call.prompt_builder,
|
||||
tools=self._get_tools_list(),
|
||||
force_use_tool=self.force_use_tool,
|
||||
files=self.latest_query_files,
|
||||
tool_call_info=[],
|
||||
using_tool_calling_llm=self.using_tool_calling_llm,
|
||||
)
|
||||
yield from self._get_response(
|
||||
[llm_call],
|
||||
check_for_tool_call=not tool_call_made,
|
||||
previously_used_tool=tool_used,
|
||||
previous_tool_response=tool_response,
|
||||
)
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
@@ -276,26 +377,32 @@ class Answer:
|
||||
using_tool_calling_llm=self.using_tool_calling_llm,
|
||||
)
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in self._get_response([llm_call]):
|
||||
processed_stream.append(processed_packet)
|
||||
if (
|
||||
isinstance(processed_packet, StreamStopInfo)
|
||||
and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
|
||||
):
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self.processing_stream = []
|
||||
self.processing_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self._processed_stream = self.processing_stream
|
||||
|
||||
@property
|
||||
def llm_answer(self) -> str:
|
||||
answer = ""
|
||||
for packet in self.processed_streamed_output:
|
||||
if not self._processed_stream and not self.current_streamed_output:
|
||||
return ""
|
||||
for packet in self.current_streamed_output or self._processed_stream or []:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
|
||||
return answer
|
||||
|
||||
@property
|
||||
def citations(self) -> list[CitationInfo]:
|
||||
citations: list[CitationInfo] = []
|
||||
for packet in self.processed_streamed_output:
|
||||
for packet in self.current_streamed_output or self._processed_stream or []:
|
||||
if isinstance(packet, CitationInfo):
|
||||
citations.append(packet)
|
||||
|
||||
|
||||
@@ -80,5 +80,7 @@ class LLMResponseHandlerManager:
|
||||
yield from self.tool_handler.handle_response_part(None, all_messages)
|
||||
yield from self.answer_handler.handle_response_part(None, all_messages)
|
||||
|
||||
def next_llm_call(self, llm_call: LLMCall) -> LLMCall | None:
|
||||
return self.tool_handler.next_llm_call(llm_call)
|
||||
def next_llm_call(
|
||||
self, llm_call: LLMCall, tool_call_made: bool = False
|
||||
) -> LLMCall | None:
|
||||
return self.tool_handler.next_llm_call(llm_call, tool_call_made)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import BaseMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
@@ -36,7 +37,10 @@ def default_build_system_message(
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
files: list[InMemoryChatFile] = [],
|
||||
previous_tool_call_count: int = 0,
|
||||
) -> HumanMessage:
|
||||
user_prompt = (
|
||||
CHAT_USER_CONTEXT_FREE_PROMPT.format(
|
||||
@@ -45,10 +49,16 @@ def default_build_user_message(
|
||||
if prompt_config.task_prompt
|
||||
else user_query
|
||||
)
|
||||
if previous_tool_call_count > 0:
|
||||
user_prompt = (
|
||||
f"You have already generated the above so do not call a tool if not necessary. "
|
||||
f"Remember the query is: `{user_prompt}`"
|
||||
)
|
||||
user_prompt = user_prompt.strip()
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
)
|
||||
|
||||
return user_msg
|
||||
|
||||
|
||||
@@ -87,6 +97,30 @@ class AnswerPromptBuilder:
|
||||
)
|
||||
|
||||
self.new_messages_and_token_cnts: list[tuple[BaseMessage, int]] = []
|
||||
self.task_reminder: tuple[HumanMessage, int] | None = None
|
||||
|
||||
def update_task_reminder(self, task_reminder: HumanMessage) -> None:
|
||||
token_count = check_message_tokens(
|
||||
task_reminder, self.llm_tokenizer_encode_func
|
||||
)
|
||||
self.task_reminder = (task_reminder, token_count)
|
||||
|
||||
def build_next_prompter(
|
||||
self, question: str, llm_answer: str, task_reminder: str | None = None
|
||||
):
|
||||
# Append the AI's previous response
|
||||
self.append_message(AIMessage(content=llm_answer))
|
||||
# Add a new user message prompting the assistant to continue
|
||||
self.append_message(
|
||||
HumanMessage(
|
||||
content=(
|
||||
f"If your previous responses did not fully answer the original query: '{question}', "
|
||||
"please continue and complete the answer. Only add information if the original question "
|
||||
"wasn't fully addressed. Use any necessary tools to provide a comprehensive response. "
|
||||
"If the original query was already completely fulfilled, do NOT call a tool."
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
def update_system_prompt(self, system_message: SystemMessage | None) -> None:
|
||||
if not system_message:
|
||||
@@ -132,6 +166,8 @@ class AnswerPromptBuilder:
|
||||
|
||||
if self.new_messages_and_token_cnts:
|
||||
final_messages_with_tokens.extend(self.new_messages_and_token_cnts)
|
||||
if self.task_reminder:
|
||||
final_messages_with_tokens.append(self.task_reminder)
|
||||
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
|
||||
@@ -173,7 +173,9 @@ class ToolResponseHandler:
|
||||
|
||||
return
|
||||
|
||||
def next_llm_call(self, current_llm_call: LLMCall) -> LLMCall | None:
|
||||
def next_llm_call(
|
||||
self, current_llm_call: LLMCall, tool_call_made: bool
|
||||
) -> LLMCall | None:
|
||||
if (
|
||||
self.tool_runner is None
|
||||
or self.tool_call_summary is None
|
||||
@@ -191,7 +193,9 @@ class ToolResponseHandler:
|
||||
)
|
||||
return LLMCall(
|
||||
prompt_builder=new_prompt_builder,
|
||||
tools=[], # for now, only allow one tool call per response
|
||||
tools=self.tools
|
||||
if not tool_call_made
|
||||
else [], # for now, only allow one tool call per response
|
||||
force_use_tool=ForceUseTool(
|
||||
force_use=False,
|
||||
tool_name="",
|
||||
|
||||
@@ -80,3 +80,15 @@ class Tool(abc.ABC):
|
||||
using_tool_calling_llm: bool,
|
||||
) -> "AnswerPromptBuilder":
|
||||
raise NotImplementedError
|
||||
|
||||
# This is the prompt builder that is used when the tool call AND LLM response has been updated
|
||||
# and we need to build the next prompt (for LLM calling tools)
|
||||
# @abc.abstractmethod
|
||||
|
||||
def build_prompt_after_tool_call(
|
||||
self,
|
||||
prompt_builder: "AnswerPromptBuilder",
|
||||
query: str,
|
||||
llm_answer: str,
|
||||
) -> "AnswerPromptBuilder":
|
||||
pass
|
||||
|
||||
@@ -4,6 +4,8 @@ from enum import Enum
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.messages import HumanMessage
|
||||
from litellm import image_generation # type: ignore
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -22,6 +24,9 @@ from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool_implementations.images.prompt import (
|
||||
build_image_generation_user_prompt,
|
||||
)
|
||||
from danswer.tools.tool_implementations.images.prompt import (
|
||||
build_image_generation_user_task_prompt,
|
||||
)
|
||||
from danswer.utils.headers import build_llm_extra_headers
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -297,3 +302,49 @@ class ImageGenerationTool(Tool):
|
||||
)
|
||||
|
||||
return prompt_builder
|
||||
|
||||
def build_prompt_after_tool_call(
|
||||
self,
|
||||
prompt_builder: "AnswerPromptBuilder",
|
||||
query: str,
|
||||
llm_answer: str,
|
||||
tool_responses: "ToolResponse",
|
||||
) -> "AnswerPromptBuilder":
|
||||
# Append the assistant's previous response to the message history
|
||||
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], tool_responses.response
|
||||
)
|
||||
|
||||
if img_generation_response is None:
|
||||
raise ValueError("No image generation response found")
|
||||
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
|
||||
# Build a user message that includes the images generated
|
||||
user_message = build_image_generation_user_task_prompt(
|
||||
img_urls=img_urls,
|
||||
)
|
||||
prompt_builder.update_user_prompt(HumanMessage(content=query))
|
||||
|
||||
# Update the user prompt with the new message containing images
|
||||
prompt_builder.append_message(user_message)
|
||||
|
||||
prompt_builder.append_message(
|
||||
AIMessage(
|
||||
content=f"The images I generated can be described as the following: {llm_answer}"
|
||||
)
|
||||
)
|
||||
|
||||
# Append a new user message reminding the assistant of the original query and what remains to be done
|
||||
prompt_builder.update_task_reminder(
|
||||
HumanMessage(
|
||||
content=f"Reminder: the original request was: '{query}'.\n\n"
|
||||
"You generated the above images as part of this request. "
|
||||
"If any parts have not been fulfilled, please proceed to complete them using the appropriate tools. "
|
||||
"If the original request has been fulfilled with the prior messages,"
|
||||
"you can provide a final summary and DO NOT call a tool."
|
||||
)
|
||||
)
|
||||
|
||||
return prompt_builder
|
||||
|
||||
@@ -10,6 +10,11 @@ Can you please summarize them in a sentence or two? Do NOT include image urls or
|
||||
"""
|
||||
|
||||
|
||||
IMG_GENERATION_USER_PROMPT = """
|
||||
These are the IMAGES you generated
|
||||
"""
|
||||
|
||||
|
||||
def build_image_generation_user_prompt(
|
||||
query: str, img_urls: list[str] | None = None
|
||||
) -> HumanMessage:
|
||||
@@ -19,3 +24,14 @@ def build_image_generation_user_prompt(
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def build_image_generation_user_task_prompt(
|
||||
img_urls: list[str] | None = None,
|
||||
) -> HumanMessage:
|
||||
return HumanMessage(
|
||||
content=build_content_with_imgs(
|
||||
message=IMG_GENERATION_USER_PROMPT,
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -131,8 +131,6 @@ export function ChatPage({
|
||||
|
||||
const {
|
||||
chatSessions,
|
||||
availableSources,
|
||||
availableDocumentSets,
|
||||
llmProviders,
|
||||
folders,
|
||||
openedFolders,
|
||||
@@ -2173,6 +2171,25 @@ export function ChatPage({
|
||||
) {
|
||||
return <></>;
|
||||
}
|
||||
const mostRecentNonAIParent = messageHistory
|
||||
.slice(0, i)
|
||||
.reverse()
|
||||
.find((msg) => msg.type !== "assistant");
|
||||
|
||||
const hasChildMessage =
|
||||
message.latestChildMessageId !== null &&
|
||||
message.latestChildMessageId !== undefined;
|
||||
const childMessage = hasChildMessage
|
||||
? messageMap.get(
|
||||
message.latestChildMessageId!
|
||||
)
|
||||
: null;
|
||||
|
||||
const hasParentAI =
|
||||
parentMessage?.type == "assistant";
|
||||
const hasChildAI =
|
||||
childMessage?.type == "assistant";
|
||||
|
||||
return (
|
||||
<div
|
||||
id={`message-${message.messageId}`}
|
||||
@@ -2184,6 +2201,9 @@ export function ChatPage({
|
||||
}
|
||||
>
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
hasChildAI={hasChildAI}
|
||||
hasParentAI={hasParentAI}
|
||||
continueGenerating={
|
||||
i == messageHistory.length - 1 &&
|
||||
currentCanContinue()
|
||||
@@ -2193,7 +2213,7 @@ export function ChatPage({
|
||||
overriddenModel={message.overridden_model}
|
||||
regenerate={createRegenerator({
|
||||
messageId: message.messageId,
|
||||
parentMessage: parentMessage!,
|
||||
parentMessage: mostRecentNonAIParent!,
|
||||
})}
|
||||
otherMessagesCanSwitchTo={
|
||||
parentMessage?.childrenMessageIds || []
|
||||
@@ -2340,6 +2360,7 @@ export function ChatPage({
|
||||
return (
|
||||
<div key={messageReactComponentKey}>
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
currentPersona={liveAssistant}
|
||||
messageId={message.messageId}
|
||||
content={
|
||||
@@ -2382,6 +2403,7 @@ export function ChatPage({
|
||||
key={`${messageHistory.length}-${chatSessionIdRef.current}`}
|
||||
>
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
key={-3}
|
||||
currentPersona={liveAssistant}
|
||||
alternativeAssistant={
|
||||
@@ -2406,6 +2428,7 @@ export function ChatPage({
|
||||
{loadingError && (
|
||||
<div key={-1}>
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
currentPersona={liveAssistant}
|
||||
messageId={-1}
|
||||
content={
|
||||
|
||||
@@ -144,3 +144,10 @@ export interface StreamingError {
|
||||
error: string;
|
||||
stack_trace: string;
|
||||
}
|
||||
|
||||
export interface ImageGenerationResult {
|
||||
revised_prompt: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export type ImageGenerationResults = ImageGenerationResult[];
|
||||
|
||||
@@ -1,29 +1,73 @@
|
||||
import { Citation } from "@/components/search/results/Citation";
|
||||
import React, { memo } from "react";
|
||||
import { IMAGE_GENERATION_TOOL_NAME } from "../tools/constants";
|
||||
|
||||
export const MemoizedLink = memo((props: any) => {
|
||||
const { node, ...rest } = props;
|
||||
const value = rest.children;
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "@/components/ui/popover";
|
||||
import { SearchIcon } from "lucide-react";
|
||||
import DualPromptDisplay from "../tools/ImageCitation";
|
||||
import { PopupSpec, usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { ImageGenerationResults, ToolCallFinalResult } from "../interfaces";
|
||||
|
||||
if (value?.toString().startsWith("*")) {
|
||||
return (
|
||||
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
|
||||
);
|
||||
} else if (value?.toString().startsWith("[")) {
|
||||
return <Citation link={rest?.href}>{rest.children}</Citation>;
|
||||
} else {
|
||||
return (
|
||||
<a
|
||||
onMouseDown={() =>
|
||||
rest.href ? window.open(rest.href, "_blank") : undefined
|
||||
}
|
||||
className="cursor-pointer text-link hover:text-link-hover"
|
||||
>
|
||||
{rest.children}
|
||||
</a>
|
||||
);
|
||||
export const MemoizedLink = memo(
|
||||
({
|
||||
toolCall,
|
||||
setPopup,
|
||||
...props
|
||||
}: {
|
||||
toolCall?: ToolCallFinalResult;
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
} & any) => {
|
||||
const { node, ...rest } = props;
|
||||
const value = rest.children;
|
||||
|
||||
if (value?.toString().startsWith(IMAGE_GENERATION_TOOL_NAME)) {
|
||||
const imageGenerationResult =
|
||||
toolCall?.tool_result as ImageGenerationResults;
|
||||
|
||||
return (
|
||||
<Popover>
|
||||
<PopoverTrigger asChild>
|
||||
<span className="inline-block">
|
||||
<SearchIcon className="cursor-pointer flex-none text-blue-500 hover:text-blue-700 !h-4 !w-4 inline-block" />
|
||||
</span>
|
||||
</PopoverTrigger>
|
||||
<PopoverContent className="w-96" side="top" align="center">
|
||||
<DualPromptDisplay
|
||||
arg="Prompt"
|
||||
setPopup={setPopup!}
|
||||
prompts={imageGenerationResult.map(
|
||||
(result) => result.revised_prompt
|
||||
)}
|
||||
/>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
|
||||
if (value?.toString().startsWith("*")) {
|
||||
return (
|
||||
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
|
||||
);
|
||||
} else if (value?.toString().startsWith("[")) {
|
||||
return <Citation link={rest?.href}>{rest.children}</Citation>;
|
||||
} else {
|
||||
return (
|
||||
<a
|
||||
onMouseDown={() =>
|
||||
rest.href ? window.open(rest.href, "_blank") : undefined
|
||||
}
|
||||
className="cursor-pointer text-link hover:text-link-hover"
|
||||
>
|
||||
{rest.children}
|
||||
</a>
|
||||
);
|
||||
}
|
||||
}
|
||||
});
|
||||
);
|
||||
|
||||
export const MemoizedParagraph = memo(({ ...props }: any) => {
|
||||
return <p {...props} className="text-default" />;
|
||||
|
||||
@@ -62,6 +62,7 @@ import { MemoizedLink, MemoizedParagraph } from "./MemoizedTextComponents";
|
||||
import { extractCodeText } from "./codeUtils";
|
||||
import ToolResult from "../../../components/tools/ToolResult";
|
||||
import CsvContent from "../../../components/tools/CSVContent";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
|
||||
const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
SEARCH_TOOL_NAME,
|
||||
@@ -153,6 +154,9 @@ function FileDisplay({
|
||||
}
|
||||
|
||||
export const AIMessage = ({
|
||||
setPopup,
|
||||
hasChildAI,
|
||||
hasParentAI,
|
||||
regenerate,
|
||||
overriddenModel,
|
||||
continueGenerating,
|
||||
@@ -179,6 +183,9 @@ export const AIMessage = ({
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
}: {
|
||||
setPopup?: (popupSpec: PopupSpec | null) => void;
|
||||
hasChildAI?: boolean;
|
||||
hasParentAI?: boolean;
|
||||
shared?: boolean;
|
||||
isActive?: boolean;
|
||||
continueGenerating?: () => void;
|
||||
@@ -227,6 +234,13 @@ export const AIMessage = ({
|
||||
return content;
|
||||
}
|
||||
}
|
||||
if (
|
||||
isComplete &&
|
||||
toolCall?.tool_result &&
|
||||
toolCall.tool_name == IMAGE_GENERATION_TOOL_NAME
|
||||
) {
|
||||
return content + ` [${toolCall.tool_name}]()`;
|
||||
}
|
||||
|
||||
return content + (!isComplete && !toolCallGenerating ? " [*]() " : "");
|
||||
};
|
||||
@@ -296,7 +310,9 @@ export const AIMessage = ({
|
||||
|
||||
const markdownComponents = useMemo(
|
||||
() => ({
|
||||
a: MemoizedLink,
|
||||
a: (props: any) => (
|
||||
<MemoizedLink {...props} toolCall={toolCall} setPopup={setPopup} />
|
||||
),
|
||||
p: MemoizedParagraph,
|
||||
code: ({ node, className, children, ...props }: any) => {
|
||||
const codeText = extractCodeText(
|
||||
@@ -312,7 +328,7 @@ export const AIMessage = ({
|
||||
);
|
||||
},
|
||||
}),
|
||||
[finalContent]
|
||||
[finalContent, toolCall]
|
||||
);
|
||||
|
||||
const renderedMarkdown = useMemo(() => {
|
||||
@@ -338,7 +354,7 @@ export const AIMessage = ({
|
||||
<div
|
||||
id="danswer-ai-message"
|
||||
ref={trackedElementRef}
|
||||
className={"py-5 ml-4 px-5 relative flex "}
|
||||
className={`${hasParentAI ? "pb-5" : "py-5"} px-2 lg:px-5 relative flex `}
|
||||
>
|
||||
<div
|
||||
className={`mx-auto ${
|
||||
@@ -347,10 +363,14 @@ export const AIMessage = ({
|
||||
>
|
||||
<div className={`desktop:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
|
||||
<div className="flex">
|
||||
<AssistantIcon
|
||||
size="small"
|
||||
assistant={alternativeAssistant || currentPersona}
|
||||
/>
|
||||
{!hasParentAI ? (
|
||||
<AssistantIcon
|
||||
size="small"
|
||||
assistant={alternativeAssistant || currentPersona}
|
||||
/>
|
||||
) : (
|
||||
<div className="w-6" />
|
||||
)}
|
||||
|
||||
<div className="w-full">
|
||||
<div className="max-w-message-max break-words">
|
||||
@@ -514,7 +534,8 @@ export const AIMessage = ({
|
||||
)}
|
||||
</div>
|
||||
|
||||
{handleFeedback &&
|
||||
{!hasChildAI &&
|
||||
handleFeedback &&
|
||||
(isActive ? (
|
||||
<div
|
||||
className={`
|
||||
|
||||
@@ -17,6 +17,7 @@ import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
|
||||
function BackToDanswerButton() {
|
||||
const router = useRouter();
|
||||
@@ -41,6 +42,8 @@ export function SharedChatDisplay({
|
||||
persona: Persona;
|
||||
}) {
|
||||
const [isReady, setIsReady] = useState(false);
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
useEffect(() => {
|
||||
Prism.highlightAll();
|
||||
setIsReady(true);
|
||||
@@ -64,6 +67,7 @@ export function SharedChatDisplay({
|
||||
|
||||
return (
|
||||
<div className="w-full h-[100dvh] overflow-hidden">
|
||||
{popup}
|
||||
<div className="flex max-h-full overflow-hidden pb-[72px]">
|
||||
<div className="flex w-full overflow-hidden overflow-y-scroll">
|
||||
<div className="w-full h-full flex-col flex max-w-message-max mx-auto">
|
||||
@@ -93,6 +97,7 @@ export function SharedChatDisplay({
|
||||
} else {
|
||||
return (
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
shared
|
||||
currentPersona={persona}
|
||||
key={message.messageId}
|
||||
|
||||
92
web/src/app/chat/tools/ImageCitation.tsx
Normal file
92
web/src/app/chat/tools/ImageCitation.tsx
Normal file
@@ -0,0 +1,92 @@
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { CopyIcon } from "@/components/icons/icons";
|
||||
import { Separator } from "@/components/ui/separator";
|
||||
import React, { useState } from "react";
|
||||
import { FiCheck } from "react-icons/fi";
|
||||
|
||||
interface PromptSectionProps {
|
||||
prompt: string;
|
||||
arg: string;
|
||||
index: number;
|
||||
copied: number | null;
|
||||
onCopy: (text: string, index: number) => void;
|
||||
}
|
||||
|
||||
const PromptSection: React.FC<PromptSectionProps> = ({
|
||||
prompt,
|
||||
arg,
|
||||
index,
|
||||
copied,
|
||||
onCopy,
|
||||
}) => (
|
||||
<div className="w-full p-2 rounded-lg">
|
||||
<h2 className="text-lg font-semibold mb-2">
|
||||
{arg} {index + 1}
|
||||
</h2>
|
||||
<p className="line-clamp-6 text-sm text-gray-800">{prompt}</p>
|
||||
<button
|
||||
onMouseDown={() => onCopy(prompt, index)}
|
||||
className="flex mt-2 text-sm cursor-pointer items-center justify-center py-2 px-3 border border-background-200 bg-inverted text-text-900 rounded-full hover:bg-background-100 transition duration-200"
|
||||
>
|
||||
{copied === index ? (
|
||||
<>
|
||||
<FiCheck className="mr-2" size={16} />
|
||||
Copied!
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<CopyIcon className="mr-2" size={16} />
|
||||
Copy
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
|
||||
interface DualPromptDisplayProps {
|
||||
prompts: string[];
|
||||
arg: string;
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
}
|
||||
|
||||
const DualPromptDisplay: React.FC<DualPromptDisplayProps> = ({
|
||||
prompts,
|
||||
arg,
|
||||
setPopup,
|
||||
}) => {
|
||||
const [copied, setCopied] = useState<number | null>(null);
|
||||
|
||||
const copyToClipboard = (text: string, index: number) => {
|
||||
navigator.clipboard
|
||||
.writeText(text)
|
||||
.then(() => {
|
||||
setPopup({ message: "Copied to clipboard", type: "success" });
|
||||
setCopied(index);
|
||||
setTimeout(() => setCopied(null), 2000);
|
||||
})
|
||||
.catch(() => {
|
||||
setPopup({ message: "Failed to copy", type: "error" });
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="w-full bg-inverted mx-auto rounded-lg">
|
||||
<div className="flex flex-col gap-x-4">
|
||||
{prompts.map((prompt, index) => (
|
||||
<React.Fragment key={index}>
|
||||
{index > 0 && <Separator />}
|
||||
<PromptSection
|
||||
prompt={prompt}
|
||||
arg={arg}
|
||||
index={index}
|
||||
copied={copied}
|
||||
onCopy={copyToClipboard}
|
||||
/>
|
||||
</React.Fragment>
|
||||
))}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default DualPromptDisplay;
|
||||
Reference in New Issue
Block a user