forked from github/onyx
Compare commits
18 Commits
main
...
final_sequ
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1285b2f4d4 | ||
|
|
842628771b | ||
|
|
7a9d5bd92e | ||
|
|
4f3b513ccb | ||
|
|
cd454dd780 | ||
|
|
9140ee99cb | ||
|
|
a64f27c895 | ||
|
|
fdf5611a35 | ||
|
|
c4f483d100 | ||
|
|
fc28c6b9e1 | ||
|
|
33e25dbd8b | ||
|
|
659e8cb69e | ||
|
|
681175e9c3 | ||
|
|
de18ec7ea4 | ||
|
|
9edbb0806d | ||
|
|
63d10e7482 | ||
|
|
ff6a15b5af | ||
|
|
49397e8a86 |
@@ -0,0 +1,65 @@
|
||||
"""single tool call per message
|
||||
|
||||
|
||||
Revision ID: 4e8e7ae58189
|
||||
Revises: 5c7fdadae813
|
||||
Create Date: 2024-09-09 10:07:58.008838
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4e8e7ae58189"
|
||||
down_revision = "5c7fdadae813"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the new column
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_tool_call",
|
||||
"chat_message",
|
||||
"tool_call",
|
||||
["tool_call_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Migrate existing data
|
||||
op.execute(
|
||||
"UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
|
||||
)
|
||||
|
||||
# Drop the old relationship
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.drop_column("tool_call", "message_id")
|
||||
|
||||
# Add a unique constraint to ensure one-to-one relationship
|
||||
op.create_unique_constraint(
|
||||
"uq_chat_message_tool_call_id", "chat_message", ["tool_call_id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the old column
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
|
||||
)
|
||||
|
||||
# Migrate data back
|
||||
op.execute(
|
||||
"UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
|
||||
)
|
||||
|
||||
# Drop the new column
|
||||
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
|
||||
op.drop_column("chat_message", "tool_call_id")
|
||||
@@ -48,6 +48,8 @@ class QADocsResponse(RetrievalDocs):
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
FINISHED = "finished"
|
||||
NEW_RESPONSE = "new_response"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
|
||||
@@ -18,6 +18,8 @@ from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
@@ -617,6 +619,11 @@ def stream_chat_message_objects(
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
@@ -662,86 +669,168 @@ def stream_chat_message_objects(
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
yielded_message_id_info = True
|
||||
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
if isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
|
||||
break
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
db_citations = None
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
if reference_db_search_docs:
|
||||
db_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
# Saving Gen AI answer and responding with message info
|
||||
if tool_result is None:
|
||||
tool_call = None
|
||||
else:
|
||||
tool_call = ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=cast(
|
||||
QADocsResponse, qa_docs_response
|
||||
).rephrased_query
|
||||
if qa_docs_response is not None
|
||||
else None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=cast(MessageSpecificCitations, db_citations).citation_map
|
||||
if db_citations is not None
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message.id
|
||||
if user_message is not None
|
||||
else gen_ai_response_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yielded_message_id_info = False
|
||||
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
reference_db_search_docs = None
|
||||
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
if not yielded_message_id_info:
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=gen_ai_response_message.id,
|
||||
reserved_assistant_message_id=reserved_message_id,
|
||||
)
|
||||
yielded_message_id_info = True
|
||||
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(
|
||||
CustomToolCallSummary, packet.response
|
||||
)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
logger.debug("Reached end of stream")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
@@ -767,11 +856,8 @@ def stream_chat_message_objects(
|
||||
)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
if answer.llm_answer == "":
|
||||
return
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
@@ -786,16 +872,14 @@ def stream_chat_message_objects(
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
tool_call=ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
if tool_result
|
||||
else [],
|
||||
else None,
|
||||
)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
|
||||
@@ -178,8 +178,14 @@ def delete_search_doc_message_relationship(
|
||||
|
||||
|
||||
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
|
||||
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
|
||||
db_session.execute(stmt)
|
||||
chat_message = (
|
||||
db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
|
||||
)
|
||||
if chat_message and chat_message.tool_call_id:
|
||||
stmt = delete(ToolCall).where(ToolCall.id == chat_message.tool_call_id)
|
||||
db_session.execute(stmt)
|
||||
chat_message.tool_call_id = None
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -388,7 +394,7 @@ def get_chat_messages_by_session(
|
||||
)
|
||||
|
||||
if prefetch_tool_calls:
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_call))
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
result = db_session.scalars(stmt).all()
|
||||
@@ -474,7 +480,7 @@ def create_new_chat_message(
|
||||
alternate_assistant_id: int | None = None,
|
||||
# Maps the citation number [n] to the DB SearchDoc
|
||||
citations: dict[int, int] | None = None,
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
tool_call: ToolCall | None = None,
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
overridden_model: str | None = None,
|
||||
@@ -494,7 +500,7 @@ def create_new_chat_message(
|
||||
existing_message.message_type = message_type
|
||||
existing_message.citations = citations
|
||||
existing_message.files = files
|
||||
existing_message.tool_calls = tool_calls if tool_calls else []
|
||||
existing_message.tool_call = tool_call if tool_call else None
|
||||
existing_message.error = error
|
||||
existing_message.alternate_assistant_id = alternate_assistant_id
|
||||
existing_message.overridden_model = overridden_model
|
||||
@@ -513,7 +519,7 @@ def create_new_chat_message(
|
||||
message_type=message_type,
|
||||
citations=citations,
|
||||
files=files,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
tool_call=tool_call if tool_call else None,
|
||||
error=error,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
overridden_model=overridden_model,
|
||||
@@ -747,14 +753,13 @@ def translate_db_message_to_chat_message_detail(
|
||||
time_sent=chat_message.time_sent,
|
||||
citations=chat_message.citations,
|
||||
files=chat_message.files or [],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
)
|
||||
|
||||
@@ -854,10 +854,8 @@ class ToolCall(Base):
|
||||
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
|
||||
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
|
||||
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
|
||||
message: Mapped["ChatMessage"] = relationship(
|
||||
"ChatMessage", back_populates="tool_calls"
|
||||
"ChatMessage", back_populates="tool_call"
|
||||
)
|
||||
|
||||
|
||||
@@ -984,9 +982,14 @@ class ChatMessage(Base):
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_calls: Mapped[list["ToolCall"]] = relationship(
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
|
||||
tool_call_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("tool_call.id"), nullable=True
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_call: Mapped["ToolCall"] = relationship(
|
||||
"ToolCall", back_populates="message", foreign_keys=[tool_call_id]
|
||||
)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
|
||||
@@ -16,6 +16,7 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
@@ -68,6 +69,7 @@ from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MAX_TOOL_CALLS
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -161,6 +163,10 @@ class Answer:
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
self._is_cancelled = False
|
||||
|
||||
self.final_context_docs: list = []
|
||||
self.current_streamed_output: list = []
|
||||
self.processing_stream: list = []
|
||||
|
||||
def _update_prompt_builder_for_search_tool(
|
||||
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
|
||||
) -> None:
|
||||
@@ -196,128 +202,213 @@ class Answer:
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
tool_calls = 0
|
||||
initiated = False
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
|
||||
# / need to generate the args
|
||||
tool_call_chunk = AIMessageChunk(
|
||||
content="",
|
||||
)
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
else:
|
||||
# if tool calling is supported, first try the raw message
|
||||
# to see if we don't need to use any tools
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
]
|
||||
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
|
||||
if initiated:
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
initiated = True
|
||||
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
tool_choice="required" if self.force_use_tool.force_use else None,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
):
|
||||
if tool_call_chunk is None:
|
||||
tool_call_chunk = message
|
||||
else:
|
||||
tool_call_chunk += message # type: ignore
|
||||
else:
|
||||
if message.content:
|
||||
if self.is_cancelled:
|
||||
return
|
||||
yield cast(str, message.content)
|
||||
if (
|
||||
message.additional_kwargs.get("usage_metadata", {}).get("stop")
|
||||
== "length"
|
||||
):
|
||||
yield StreamStopInfo(
|
||||
stop_reason=StreamStopReason.CONTEXT_LENGTH
|
||||
)
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
|
||||
if not tool_call_chunk:
|
||||
return # no tool call needed
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
|
||||
# if we have a tool call, we need to call the tool
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
yield from tool_runner.tool_responses()
|
||||
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
|
||||
else:
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
default_build_user_message(
|
||||
self.question,
|
||||
self.prompt_config,
|
||||
self.latest_query_files,
|
||||
tool_calls,
|
||||
)
|
||||
)
|
||||
yield tool_runner.tool_final_result()
|
||||
prompt = prompt_builder.build()
|
||||
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
]
|
||||
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=[tool.tool_definition() for tool in self.tools],
|
||||
)
|
||||
existing_message = ""
|
||||
|
||||
return
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
tool_choice="required" if self.force_use_tool.force_use else None,
|
||||
):
|
||||
if isinstance(message, AIMessageChunk) and (
|
||||
message.tool_call_chunks or message.tool_calls
|
||||
):
|
||||
if tool_call_chunk is None:
|
||||
tool_call_chunk = message
|
||||
else:
|
||||
if len(existing_message) > 0:
|
||||
yield StreamStopInfo(
|
||||
stop_reason=StreamStopReason.NEW_RESPONSE
|
||||
)
|
||||
existing_message = ""
|
||||
|
||||
tool_call_chunk += message # type: ignore
|
||||
else:
|
||||
if message.content:
|
||||
if self.is_cancelled or tool_calls > 0:
|
||||
return
|
||||
|
||||
existing_message += cast(str, message.content)
|
||||
yield cast(str, message.content)
|
||||
if (
|
||||
message.additional_kwargs.get("usage_metadata", {}).get(
|
||||
"stop"
|
||||
)
|
||||
== "length"
|
||||
):
|
||||
yield StreamStopInfo(
|
||||
stop_reason=StreamStopReason.CONTEXT_LENGTH
|
||||
)
|
||||
|
||||
if not tool_call_chunk:
|
||||
logger.info("Skipped tool call but generated message")
|
||||
return
|
||||
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
|
||||
for tool_call_request in tool_call_requests:
|
||||
tool_calls += 1
|
||||
known_tools_by_name = [
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == tool_call_request["name"]
|
||||
]
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
yield tool_kickoff
|
||||
|
||||
yield from tool_runner.tool_responses()
|
||||
|
||||
tool_responses = []
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
yield response
|
||||
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
)
|
||||
)
|
||||
|
||||
yield tool_runner.tool_final_result()
|
||||
|
||||
# Update message history with tool call and response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=self.question,
|
||||
message_type=MessageType.USER,
|
||||
token_count=len(self.llm_tokenizer.encode(self.question)),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=str(tool_call_request),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(
|
||||
self.llm_tokenizer.encode(str(tool_call_request))
|
||||
),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message="\n".join(str(response) for response in tool_responses),
|
||||
message_type=MessageType.SYSTEM,
|
||||
token_count=len(
|
||||
self.llm_tokenizer.encode(
|
||||
"\n".join(str(response) for response in tool_responses)
|
||||
)
|
||||
),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
|
||||
# Generate response based on updated message history
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
|
||||
response_content = ""
|
||||
for content in self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=None
|
||||
# tools=[tool.tool_definition() for tool in self.tools],
|
||||
):
|
||||
if isinstance(content, str):
|
||||
response_content += content
|
||||
yield content
|
||||
|
||||
# Update message history with LLM response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=response_content,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(self.llm_tokenizer.encode(response_content)),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
|
||||
# This method processes the LLM stream and yields the content or stop information
|
||||
def _process_llm_stream(
|
||||
@@ -346,139 +437,197 @@ class Answer:
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
tool_calls = 0
|
||||
initiated = False
|
||||
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
|
||||
if initiated:
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
iter(
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
initiated = True
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
iter(
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(
|
||||
f"Tool '{self.force_use_tool.tool_name}' not found"
|
||||
)
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=None,
|
||||
)
|
||||
return
|
||||
tool_calls += 1
|
||||
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
tool_responses = []
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
yield response
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_documents = cast(list[LlmDoc], response.response)
|
||||
yield response
|
||||
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
)
|
||||
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = []
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], response.response
|
||||
)
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
|
||||
yield response
|
||||
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
else:
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
final_result = tool_runner.tool_final_result()
|
||||
yield final_result
|
||||
|
||||
# Update message history
|
||||
self.message_history.extend(
|
||||
[
|
||||
PreviousMessage(
|
||||
message=str(self.question),
|
||||
message_type=MessageType.USER,
|
||||
token_count=len(self.llm_tokenizer.encode(str(self.question))),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
),
|
||||
PreviousMessage(
|
||||
message=f"Tool used: {tool.name}",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(
|
||||
self.llm_tokenizer.encode(f"Tool used: {tool.name}")
|
||||
),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
),
|
||||
PreviousMessage(
|
||||
message=str(final_result),
|
||||
message_type=MessageType.SYSTEM,
|
||||
token_count=len(self.llm_tokenizer.encode(str(final_result))),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Generate response based on updated message history
|
||||
prompt = prompt_builder.build()
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=None,
|
||||
)
|
||||
return
|
||||
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
response_content = ""
|
||||
for content in self._process_llm_stream(prompt=prompt, tools=None):
|
||||
if isinstance(content, str):
|
||||
response_content += content
|
||||
yield content
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_documents = cast(list[LlmDoc], response.response)
|
||||
yield response
|
||||
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
)
|
||||
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = []
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], response.response
|
||||
)
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
|
||||
yield response
|
||||
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
img_urls=img_urls,
|
||||
# Update message history with LLM response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=response_content,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(self.llm_tokenizer.encode(response_content)),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
else:
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
final = tool_runner.tool_final_result()
|
||||
|
||||
yield final
|
||||
|
||||
prompt = prompt_builder.build()
|
||||
|
||||
yield from self._process_llm_stream(prompt=prompt, tools=None)
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
@@ -495,6 +644,8 @@ class Answer:
|
||||
else self._raw_output_for_non_explicit_tool_calling_llms()
|
||||
)
|
||||
|
||||
self.processing_stream = []
|
||||
|
||||
def _process_stream(
|
||||
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
|
||||
) -> AnswerStream:
|
||||
@@ -535,56 +686,70 @@ class Answer:
|
||||
|
||||
yield message
|
||||
else:
|
||||
# assumes all tool responses will come first, then the final answer
|
||||
break
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
stream_stop_info = None
|
||||
new_kickoff = None
|
||||
|
||||
stream_stop_info = None
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
nonlocal new_kickoff
|
||||
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
yield cast(str, message)
|
||||
for item in stream:
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
yield cast(str, item)
|
||||
yield cast(str, message)
|
||||
for item in stream:
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
if isinstance(item, ToolCallKickoff):
|
||||
new_kickoff = item
|
||||
return
|
||||
else:
|
||||
yield cast(str, item)
|
||||
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
|
||||
# handle new tool call (continuation of message)
|
||||
if new_kickoff:
|
||||
yield new_kickoff
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in _process_stream(output_generator):
|
||||
processed_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
if (
|
||||
isinstance(processed_packet, StreamStopInfo)
|
||||
and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
|
||||
):
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self.processing_stream = []
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
self.processing_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self._processed_stream = self.processing_stream
|
||||
|
||||
@property
|
||||
def llm_answer(self) -> str:
|
||||
answer = ""
|
||||
for packet in self.processed_streamed_output:
|
||||
if not self._processed_stream and not self.current_streamed_output:
|
||||
return ""
|
||||
for packet in self.current_streamed_output or self._processed_stream or []:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
|
||||
return answer
|
||||
|
||||
@property
|
||||
def citations(self) -> list[CitationInfo]:
|
||||
citations: list[CitationInfo] = []
|
||||
for packet in self.processed_streamed_output:
|
||||
for packet in self.current_streamed_output:
|
||||
if isinstance(packet, CitationInfo):
|
||||
citations.append(packet)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
@@ -51,14 +51,13 @@ class PreviousMessage(BaseModel):
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
|
||||
@@ -36,7 +36,10 @@ def default_build_system_message(
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
files: list[InMemoryChatFile] = [],
|
||||
previous_tool_calls: int = 0,
|
||||
) -> HumanMessage:
|
||||
user_prompt = (
|
||||
CHAT_USER_CONTEXT_FREE_PROMPT.format(
|
||||
@@ -45,6 +48,12 @@ def default_build_user_message(
|
||||
if prompt_config.task_prompt
|
||||
else user_query
|
||||
)
|
||||
if previous_tool_calls > 0:
|
||||
user_prompt = (
|
||||
f"You have already generated the above so do not call a tool if not necessary. "
|
||||
f"Remember the query is: `{user_prompt}`"
|
||||
)
|
||||
|
||||
user_prompt = user_prompt.strip()
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
|
||||
@@ -112,7 +112,7 @@ def translate_danswer_msg_to_langchain(
|
||||
content = build_content_with_imgs(msg.message, files)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
return SystemMessage(content=content)
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
if msg.message_type == MessageType.USER:
|
||||
|
||||
@@ -281,14 +281,17 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
|
||||
def is_disconnected_sync() -> bool:
|
||||
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
|
||||
try:
|
||||
return not future.result(timeout=0.01)
|
||||
result = not future.result(timeout=0.01)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Asyncio timed out")
|
||||
logger.error("Asyncio timed out while checking client connection")
|
||||
return True
|
||||
except asyncio.CancelledError:
|
||||
return True
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.critical(
|
||||
f"An unexpected error occured with the disconnect check coroutine: {error_msg}"
|
||||
f"An unexpected error occurred with the disconnect check coroutine: {error_msg}"
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
@@ -178,7 +178,7 @@ class ChatMessageDetail(BaseModel):
|
||||
chat_session_id: int | None = None
|
||||
citations: dict[int, int] | None = None
|
||||
files: list[FileDescriptor]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
|
||||
@@ -4,7 +4,7 @@ from danswer.llm.utils import build_content_with_imgs
|
||||
|
||||
|
||||
IMG_GENERATION_SUMMARY_PROMPT = """
|
||||
You have just created the attached images in response to the following query: "{query}".
|
||||
You have just created the most recent attached images in response to the following query: "{query}".
|
||||
|
||||
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
|
||||
"""
|
||||
|
||||
@@ -21,6 +21,8 @@ CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
|
||||
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
|
||||
INTENT_MODEL_TAG = "v1.0.3"
|
||||
|
||||
# Tool call configs
|
||||
MAX_TOOL_CALLS = 3
|
||||
|
||||
# Bi-Encoder, other details
|
||||
DOC_EMBEDDING_CONTEXT_SIZE = 512
|
||||
|
||||
44
web/src/app/chat/AIMessageSequenceUtils.ts
Normal file
44
web/src/app/chat/AIMessageSequenceUtils.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
// This module handles AI message sequences - consecutive AI messages that are streamed
|
||||
// separately but represent a single logical message. These utilities are used for
|
||||
// processing and displaying such sequences in the chat interface.
|
||||
|
||||
import { Message } from "@/app/chat/interfaces";
|
||||
import { DanswerDocument } from "@/lib/search/interfaces";
|
||||
|
||||
// Retrieves the consecutive AI messages at the end of the message history.
|
||||
// This is useful for combining or processing the latest AI response sequence.
|
||||
export function getConsecutiveAIMessagesAtEnd(
|
||||
messageHistory: Message[]
|
||||
): Message[] {
|
||||
const aiMessages = [];
|
||||
for (let i = messageHistory.length - 1; i >= 0; i--) {
|
||||
if (messageHistory[i]?.type === "assistant") {
|
||||
aiMessages.unshift(messageHistory[i]);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return aiMessages;
|
||||
}
|
||||
|
||||
// Extracts unique documents from a sequence of AI messages.
|
||||
// This is used to compile a comprehensive list of referenced documents
|
||||
// across multiple parts of an AI response.
|
||||
export function getUniqueDocumentsFromAIMessages(
|
||||
messages: Message[]
|
||||
): DanswerDocument[] {
|
||||
const uniqueDocumentsMap = new Map<string, DanswerDocument>();
|
||||
|
||||
messages.forEach((message) => {
|
||||
if (message.documents) {
|
||||
message.documents.forEach((doc) => {
|
||||
const uniqueKey = `${doc.document_id}-${doc.chunk_ind}`;
|
||||
if (!uniqueDocumentsMap.has(uniqueKey)) {
|
||||
uniqueDocumentsMap.set(uniqueKey, doc);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return Array.from(uniqueDocumentsMap.values());
|
||||
}
|
||||
@@ -101,6 +101,8 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import { SEARCH_TOOL_NAME } from "./tools/constants";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
|
||||
import { Button } from "@tremor/react";
|
||||
import dynamic from "next/dynamic";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@@ -133,7 +135,6 @@ export function ChatPage({
|
||||
} = useChatContext();
|
||||
|
||||
const [showApiKeyModal, setShowApiKeyModal] = useState(true);
|
||||
|
||||
const { user, refreshUser, isLoadingUser } = useUser();
|
||||
|
||||
// chat session
|
||||
@@ -248,13 +249,13 @@ export function ChatPage({
|
||||
if (
|
||||
lastMessage &&
|
||||
lastMessage.type === "assistant" &&
|
||||
lastMessage.toolCalls[0] &&
|
||||
lastMessage.toolCalls[0].tool_result === undefined
|
||||
lastMessage.toolCall &&
|
||||
lastMessage.toolCall.tool_result === undefined
|
||||
) {
|
||||
const newCompleteMessageMap = new Map(
|
||||
currentMessageMap(completeMessageDetail)
|
||||
);
|
||||
const updatedMessage = { ...lastMessage, toolCalls: [] };
|
||||
const updatedMessage = { ...lastMessage, toolCall: null };
|
||||
newCompleteMessageMap.set(lastMessage.messageId, updatedMessage);
|
||||
updateCompleteMessageDetail(currentSession, newCompleteMessageMap);
|
||||
}
|
||||
@@ -483,7 +484,7 @@ export function ChatPage({
|
||||
message: "",
|
||||
type: "system",
|
||||
files: [],
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: null,
|
||||
childrenMessageIds: [firstMessageId],
|
||||
latestChildMessageId: firstMessageId,
|
||||
@@ -510,6 +511,7 @@ export function ChatPage({
|
||||
}
|
||||
newCompleteMessageMap.set(message.messageId, message);
|
||||
});
|
||||
|
||||
// if specified, make these new message the latest of the current message chain
|
||||
if (makeLatestChildMessage) {
|
||||
const currentMessageChain = buildLatestMessageChain(
|
||||
@@ -1044,8 +1046,6 @@ export function ChatPage({
|
||||
resetInputBar();
|
||||
let messageUpdates: Message[] | null = null;
|
||||
|
||||
let answer = "";
|
||||
|
||||
let stopReason: StreamStopReason | null = null;
|
||||
let query: string | null = null;
|
||||
let retrievalType: RetrievalType =
|
||||
@@ -1058,12 +1058,14 @@ export function ChatPage({
|
||||
let stackTrace: string | null = null;
|
||||
|
||||
let finalMessage: BackendMessage | null = null;
|
||||
let toolCalls: ToolCallMetadata[] = [];
|
||||
let toolCall: ToolCallMetadata | null = null;
|
||||
|
||||
let initialFetchDetails: null | {
|
||||
user_message_id: number;
|
||||
assistant_message_id: number;
|
||||
frozenMessageMap: Map<number, Message>;
|
||||
initialDynamicParentMessage: Message;
|
||||
initialDynamicAssistantMessage: Message;
|
||||
} = null;
|
||||
|
||||
try {
|
||||
@@ -1122,7 +1124,16 @@ export function ChatPage({
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
};
|
||||
|
||||
let updateFn = (messages: Message[]) => {
|
||||
return upsertToCompleteMessageMap({
|
||||
messages: messages,
|
||||
chatSessionId: currChatSessionId,
|
||||
});
|
||||
};
|
||||
|
||||
await delay(50);
|
||||
let dynamicParentMessage: Message | null = null;
|
||||
let dynamicAssistantMessage: Message | null = null;
|
||||
while (!stack.isComplete || !stack.isEmpty()) {
|
||||
await delay(0.5);
|
||||
|
||||
@@ -1156,12 +1167,12 @@ export function ChatPage({
|
||||
messageUpdates = [
|
||||
{
|
||||
messageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
? regenerationRequest?.messageId
|
||||
: user_message_id,
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
|
||||
},
|
||||
];
|
||||
@@ -1176,22 +1187,109 @@ export function ChatPage({
|
||||
});
|
||||
}
|
||||
|
||||
const { messageMap: currentFrozenMessageMap } =
|
||||
let { messageMap: currentFrozenMessageMap } =
|
||||
upsertToCompleteMessageMap({
|
||||
messages: messageUpdates,
|
||||
chatSessionId: currChatSessionId,
|
||||
});
|
||||
|
||||
const frozenMessageMap = currentFrozenMessageMap;
|
||||
let frozenMessageMap = currentFrozenMessageMap;
|
||||
regenerationRequest?.parentMessage;
|
||||
let initialDynamicParentMessage: Message = regenerationRequest
|
||||
? regenerationRequest?.parentMessage
|
||||
: {
|
||||
messageId: user_message_id!,
|
||||
message: "",
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCall: null,
|
||||
parentMessageId: error ? null : lastSuccessfulMessageId,
|
||||
childrenMessageIds: [assistant_message_id!],
|
||||
latestChildMessageId: -100,
|
||||
};
|
||||
|
||||
let initialDynamicAssistantMessage: Message = {
|
||||
messageId: assistant_message_id!,
|
||||
message: "",
|
||||
type: "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: finalMessage?.context_docs?.top_documents || documents,
|
||||
citations: finalMessage?.citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: finalMessage?.tool_call || toolCall,
|
||||
parentMessageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: user_message_id,
|
||||
alternateAssistantID: alternativeAssistant?.id,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
};
|
||||
|
||||
initialFetchDetails = {
|
||||
frozenMessageMap,
|
||||
assistant_message_id,
|
||||
user_message_id,
|
||||
initialDynamicParentMessage,
|
||||
initialDynamicAssistantMessage,
|
||||
};
|
||||
|
||||
resetRegenerationState();
|
||||
} else {
|
||||
const { user_message_id, frozenMessageMap } = initialFetchDetails;
|
||||
let {
|
||||
initialDynamicParentMessage,
|
||||
initialDynamicAssistantMessage,
|
||||
user_message_id,
|
||||
frozenMessageMap,
|
||||
} = initialFetchDetails;
|
||||
|
||||
if (
|
||||
dynamicParentMessage === null &&
|
||||
dynamicAssistantMessage === null
|
||||
) {
|
||||
dynamicParentMessage = initialDynamicParentMessage;
|
||||
dynamicAssistantMessage = initialDynamicAssistantMessage;
|
||||
|
||||
dynamicParentMessage.message = currMessage;
|
||||
}
|
||||
|
||||
if (!dynamicAssistantMessage || !dynamicParentMessage) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (Object.hasOwn(packet, "user_message_id")) {
|
||||
let newParentMessageId = dynamicParentMessage.messageId;
|
||||
const messageResponseIDInfo = packet as MessageResponseIDInfo;
|
||||
|
||||
for (const key in dynamicAssistantMessage) {
|
||||
(dynamicParentMessage as Record<string, any>)[key] = (
|
||||
dynamicAssistantMessage as Record<string, any>
|
||||
)[key];
|
||||
}
|
||||
|
||||
dynamicParentMessage.parentMessageId = newParentMessageId;
|
||||
dynamicParentMessage.latestChildMessageId =
|
||||
messageResponseIDInfo.reserved_assistant_message_id;
|
||||
dynamicParentMessage.childrenMessageIds = [
|
||||
messageResponseIDInfo.reserved_assistant_message_id,
|
||||
];
|
||||
|
||||
dynamicParentMessage.messageId =
|
||||
messageResponseIDInfo.user_message_id!;
|
||||
dynamicAssistantMessage = {
|
||||
messageId: messageResponseIDInfo.reserved_assistant_message_id,
|
||||
type: "assistant",
|
||||
message: "",
|
||||
documents: [],
|
||||
retrievalType: undefined,
|
||||
toolCall: null,
|
||||
files: [],
|
||||
parentMessageId: dynamicParentMessage.messageId,
|
||||
childrenMessageIds: [],
|
||||
latestChildMessageId: null,
|
||||
};
|
||||
}
|
||||
|
||||
setChatState((prevState) => {
|
||||
if (prevState.get(chatSessionIdRef.current!) === "loading") {
|
||||
@@ -1204,37 +1302,37 @@ export function ChatPage({
|
||||
});
|
||||
|
||||
if (Object.hasOwn(packet, "answer_piece")) {
|
||||
answer += (packet as AnswerPiecePacket).answer_piece;
|
||||
dynamicAssistantMessage.message += (
|
||||
packet as AnswerPiecePacket
|
||||
).answer_piece;
|
||||
} else if (Object.hasOwn(packet, "top_documents")) {
|
||||
documents = (packet as DocumentsResponse).top_documents;
|
||||
dynamicAssistantMessage.documents = (
|
||||
packet as DocumentsResponse
|
||||
).top_documents;
|
||||
dynamicAssistantMessage.retrievalType = RetrievalType.Search;
|
||||
retrievalType = RetrievalType.Search;
|
||||
if (documents && documents.length > 0) {
|
||||
// point to the latest message (we don't know the messageId yet, which is why
|
||||
// we have to use -1)
|
||||
setSelectedMessageForDocDisplay(user_message_id);
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "tool_name")) {
|
||||
toolCalls = [
|
||||
{
|
||||
tool_name: (packet as ToolCallMetadata).tool_name,
|
||||
tool_args: (packet as ToolCallMetadata).tool_args,
|
||||
tool_result: (packet as ToolCallMetadata).tool_result,
|
||||
},
|
||||
];
|
||||
dynamicAssistantMessage.toolCall = {
|
||||
tool_name: (packet as ToolCallMetadata).tool_name,
|
||||
tool_args: (packet as ToolCallMetadata).tool_args,
|
||||
tool_result: (packet as ToolCallMetadata).tool_result,
|
||||
};
|
||||
if (
|
||||
!toolCalls[0].tool_result ||
|
||||
toolCalls[0].tool_result == undefined
|
||||
dynamicAssistantMessage.toolCall.tool_name === SEARCH_TOOL_NAME
|
||||
) {
|
||||
dynamicAssistantMessage.query =
|
||||
dynamicAssistantMessage.toolCall.tool_args.query;
|
||||
}
|
||||
|
||||
if (
|
||||
!dynamicAssistantMessage.toolCall ||
|
||||
!dynamicAssistantMessage.toolCall.tool_result ||
|
||||
dynamicAssistantMessage.toolCall.tool_result == undefined
|
||||
) {
|
||||
updateChatState("toolBuilding", frozenSessionId);
|
||||
} else {
|
||||
updateChatState("streaming", frozenSessionId);
|
||||
}
|
||||
|
||||
// This will be consolidated in upcoming tool calls udpate,
|
||||
// but for now, we need to set query as early as possible
|
||||
if (toolCalls[0].tool_name == SEARCH_TOOL_NAME) {
|
||||
query = toolCalls[0].tool_args["query"];
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "file_ids")) {
|
||||
aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
|
||||
(fileId) => {
|
||||
@@ -1244,82 +1342,54 @@ export function ChatPage({
|
||||
};
|
||||
}
|
||||
);
|
||||
dynamicAssistantMessage.files = aiMessageImages;
|
||||
} else if (Object.hasOwn(packet, "error")) {
|
||||
error = (packet as StreamingError).error;
|
||||
stackTrace = (packet as StreamingError).stack_trace;
|
||||
dynamicAssistantMessage.stackTrace = (
|
||||
packet as StreamingError
|
||||
).stack_trace;
|
||||
} else if (Object.hasOwn(packet, "message_id")) {
|
||||
finalMessage = packet as BackendMessage;
|
||||
dynamicAssistantMessage = {
|
||||
...dynamicAssistantMessage,
|
||||
...finalMessage,
|
||||
};
|
||||
} else if (Object.hasOwn(packet, "stop_reason")) {
|
||||
const stop_reason = (packet as StreamStopInfo).stop_reason;
|
||||
|
||||
if (stop_reason === StreamStopReason.CONTEXT_LENGTH) {
|
||||
updateCanContinue(true, frozenSessionId);
|
||||
}
|
||||
}
|
||||
if (!Object.hasOwn(packet, "stop_reason")) {
|
||||
updateFn = (messages: Message[]) => {
|
||||
const replacementsMap = regenerationRequest
|
||||
? new Map([
|
||||
[
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
],
|
||||
[
|
||||
dynamicParentMessage?.messageId,
|
||||
dynamicAssistantMessage?.messageId,
|
||||
],
|
||||
] as [number, number][])
|
||||
: null;
|
||||
|
||||
// on initial message send, we insert a dummy system message
|
||||
// set this as the parent here if no parent is set
|
||||
parentMessage =
|
||||
parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
|
||||
return upsertToCompleteMessageMap({
|
||||
messages: messages,
|
||||
replacementsMap: replacementsMap,
|
||||
completeMessageMapOverride: frozenMessageMap,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
};
|
||||
|
||||
const updateFn = (messages: Message[]) => {
|
||||
const replacementsMap = regenerationRequest
|
||||
? new Map([
|
||||
[
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
],
|
||||
[
|
||||
regenerationRequest?.messageId,
|
||||
initialFetchDetails?.assistant_message_id,
|
||||
],
|
||||
] as [number, number][])
|
||||
: null;
|
||||
|
||||
return upsertToCompleteMessageMap({
|
||||
messages: messages,
|
||||
replacementsMap: replacementsMap,
|
||||
completeMessageMapOverride: frozenMessageMap,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
};
|
||||
|
||||
updateFn([
|
||||
{
|
||||
messageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: initialFetchDetails.user_message_id!,
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
parentMessageId: error ? null : lastSuccessfulMessageId,
|
||||
childrenMessageIds: [
|
||||
...(regenerationRequest?.parentMessage?.childrenMessageIds ||
|
||||
[]),
|
||||
initialFetchDetails.assistant_message_id!,
|
||||
],
|
||||
latestChildMessageId: initialFetchDetails.assistant_message_id,
|
||||
},
|
||||
{
|
||||
messageId: initialFetchDetails.assistant_message_id!,
|
||||
message: error || answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents:
|
||||
finalMessage?.context_docs?.top_documents || documents,
|
||||
citations: finalMessage?.citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCalls: finalMessage?.tool_calls || toolCalls,
|
||||
parentMessageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: initialFetchDetails.user_message_id,
|
||||
alternateAssistantID: alternativeAssistant?.id,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
},
|
||||
]);
|
||||
let { messageMap } = updateFn([
|
||||
dynamicParentMessage,
|
||||
dynamicAssistantMessage,
|
||||
]);
|
||||
frozenMessageMap = messageMap;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1333,7 +1403,7 @@ export function ChatPage({
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
|
||||
},
|
||||
{
|
||||
@@ -1343,7 +1413,7 @@ export function ChatPage({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
files: aiMessageImages || [],
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId:
|
||||
initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
|
||||
},
|
||||
@@ -1962,9 +2032,8 @@ export function ChatPage({
|
||||
completeMessageDetail
|
||||
);
|
||||
const messageReactComponentKey = `${i}-${currentSessionId()}`;
|
||||
const parentMessage = message.parentMessageId
|
||||
? messageMap.get(message.parentMessageId)
|
||||
: null;
|
||||
const parentMessage =
|
||||
i > 1 ? messageHistory[i - 1] : null;
|
||||
if (message.type === "user") {
|
||||
if (
|
||||
(currentSessionChatState == "loading" &&
|
||||
@@ -2055,6 +2124,25 @@ export function ChatPage({
|
||||
) {
|
||||
return <></>;
|
||||
}
|
||||
const mostRecentNonAIParent = messageHistory
|
||||
.slice(0, i)
|
||||
.reverse()
|
||||
.find((msg) => msg.type !== "assistant");
|
||||
|
||||
const hasChildMessage =
|
||||
message.latestChildMessageId !== null &&
|
||||
message.latestChildMessageId !== undefined;
|
||||
const childMessage = hasChildMessage
|
||||
? messageMap.get(
|
||||
message.latestChildMessageId!
|
||||
)
|
||||
: null;
|
||||
|
||||
const hasParentAI =
|
||||
parentMessage?.type == "assistant";
|
||||
const hasChildAI =
|
||||
childMessage?.type == "assistant";
|
||||
|
||||
return (
|
||||
<div
|
||||
id={`message-${message.messageId}`}
|
||||
@@ -2066,6 +2154,9 @@ export function ChatPage({
|
||||
}
|
||||
>
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
hasChildAI={hasChildAI}
|
||||
hasParentAI={hasParentAI}
|
||||
continueGenerating={
|
||||
i == messageHistory.length - 1 &&
|
||||
currentCanContinue()
|
||||
@@ -2075,7 +2166,7 @@ export function ChatPage({
|
||||
overriddenModel={message.overridden_model}
|
||||
regenerate={createRegenerator({
|
||||
messageId: message.messageId,
|
||||
parentMessage: parentMessage!,
|
||||
parentMessage: mostRecentNonAIParent!,
|
||||
})}
|
||||
otherMessagesCanSwitchTo={
|
||||
parentMessage?.childrenMessageIds || []
|
||||
@@ -2112,18 +2203,15 @@ export function ChatPage({
|
||||
}
|
||||
messageId={message.messageId}
|
||||
content={message.message}
|
||||
// content={message.message}
|
||||
files={message.files}
|
||||
query={
|
||||
messageHistory[i]?.query || undefined
|
||||
}
|
||||
personaName={liveAssistant.name}
|
||||
citedDocuments={getCitedDocumentsFromMessage(
|
||||
message
|
||||
)}
|
||||
toolCall={
|
||||
message.toolCalls &&
|
||||
message.toolCalls[0]
|
||||
message.toolCall && message.toolCall
|
||||
}
|
||||
isComplete={
|
||||
i !== messageHistory.length - 1 ||
|
||||
@@ -2147,7 +2235,6 @@ export function ChatPage({
|
||||
])
|
||||
}
|
||||
handleSearchQueryEdit={
|
||||
i === messageHistory.length - 1 &&
|
||||
currentSessionChatState == "input"
|
||||
? (newQuery) => {
|
||||
if (!previousMessage) {
|
||||
@@ -2231,7 +2318,6 @@ export function ChatPage({
|
||||
<AIMessage
|
||||
currentPersona={liveAssistant}
|
||||
messageId={message.messageId}
|
||||
personaName={liveAssistant.name}
|
||||
content={
|
||||
<p className="text-red-700 text-sm my-auto">
|
||||
{message.message}
|
||||
@@ -2279,7 +2365,6 @@ export function ChatPage({
|
||||
alternativeAssistant
|
||||
}
|
||||
messageId={null}
|
||||
personaName={liveAssistant.name}
|
||||
content={
|
||||
<div
|
||||
key={"Generating"}
|
||||
@@ -2299,7 +2384,6 @@ export function ChatPage({
|
||||
<AIMessage
|
||||
currentPersona={liveAssistant}
|
||||
messageId={-1}
|
||||
personaName={liveAssistant.name}
|
||||
content={
|
||||
<p className="text-red-700 text-sm my-auto">
|
||||
{loadingError}
|
||||
|
||||
@@ -85,7 +85,7 @@ export interface Message {
|
||||
documents?: DanswerDocument[] | null;
|
||||
citations?: CitationMap;
|
||||
files: FileDescriptor[];
|
||||
toolCalls: ToolCallMetadata[];
|
||||
toolCall: ToolCallMetadata | null;
|
||||
// for rebuilding the message tree
|
||||
parentMessageId: number | null;
|
||||
childrenMessageIds?: number[];
|
||||
@@ -120,7 +120,7 @@ export interface BackendMessage {
|
||||
time_sent: string;
|
||||
citations: CitationMap;
|
||||
files: FileDescriptor[];
|
||||
tool_calls: ToolCallFinalResult[];
|
||||
tool_call: ToolCallFinalResult | null;
|
||||
alternate_assistant_id?: number | null;
|
||||
overridden_model?: string;
|
||||
}
|
||||
@@ -143,3 +143,10 @@ export interface StreamingError {
|
||||
error: string;
|
||||
stack_trace: string;
|
||||
}
|
||||
|
||||
export interface ImageGenerationResult {
|
||||
revised_prompt: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export type ImageGenerationResults = ImageGenerationResult[];
|
||||
|
||||
@@ -435,7 +435,7 @@ export function processRawChatHistory(
|
||||
citations: messageInfo?.citations || {},
|
||||
}
|
||||
: {}),
|
||||
toolCalls: messageInfo.tool_calls,
|
||||
toolCall: messageInfo.tool_call,
|
||||
parentMessageId: messageInfo.parent_message,
|
||||
childrenMessageIds: [],
|
||||
latestChildMessageId: messageInfo.latest_child_message,
|
||||
@@ -479,6 +479,7 @@ export function buildLatestMessageChain(
|
||||
let currMessage: Message | null = rootMessage;
|
||||
while (currMessage) {
|
||||
finalMessageList.push(currMessage);
|
||||
|
||||
const childMessageNumber = currMessage.latestChildMessageId;
|
||||
if (childMessageNumber && messageMap.has(childMessageNumber)) {
|
||||
currMessage = messageMap.get(childMessageNumber) as Message;
|
||||
|
||||
@@ -8,25 +8,22 @@ import {
|
||||
FiGlobe,
|
||||
} from "react-icons/fi";
|
||||
import { FeedbackType } from "../types";
|
||||
import {
|
||||
Dispatch,
|
||||
SetStateAction,
|
||||
useContext,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
import { useContext, useEffect, useRef, useState } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import {
|
||||
DanswerDocument,
|
||||
FilteredDanswerDocument,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { SearchSummary } from "./SearchSummary";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { SkippedSearch } from "./SkippedSearch";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import { CopyButton } from "@/components/CopyButton";
|
||||
import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces";
|
||||
import {
|
||||
ChatFileType,
|
||||
FileDescriptor,
|
||||
ImageGenerationResults,
|
||||
ToolCallMetadata,
|
||||
} from "../interfaces";
|
||||
import {
|
||||
IMAGE_GENERATION_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
@@ -44,13 +41,11 @@ import "./custom-code-styles.css";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
||||
import { Citation } from "@/components/search/results/Citation";
|
||||
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
|
||||
|
||||
import {
|
||||
ThumbsUpIcon,
|
||||
ThumbsDownIcon,
|
||||
LikeFeedback,
|
||||
DislikeFeedback,
|
||||
ToolCallIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import {
|
||||
CustomTooltip,
|
||||
@@ -59,12 +54,14 @@ import {
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { Tooltip } from "@/components/tooltip/Tooltip";
|
||||
import { useMouseTracking } from "./hooks";
|
||||
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
|
||||
import RegenerateOption from "../RegenerateOption";
|
||||
import { LlmOverride } from "@/lib/hooks";
|
||||
import { ContinueGenerating } from "./ContinueMessage";
|
||||
import DualPromptDisplay from "../tools/ImagePromptCitaiton";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { Popover } from "@/components/popover/Popover";
|
||||
|
||||
const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
SEARCH_TOOL_NAME,
|
||||
@@ -121,6 +118,8 @@ function FileDisplay({
|
||||
}
|
||||
|
||||
export const AIMessage = ({
|
||||
hasChildAI,
|
||||
hasParentAI,
|
||||
regenerate,
|
||||
overriddenModel,
|
||||
continueGenerating,
|
||||
@@ -134,7 +133,6 @@ export const AIMessage = ({
|
||||
files,
|
||||
selectedDocuments,
|
||||
query,
|
||||
personaName,
|
||||
citedDocuments,
|
||||
toolCall,
|
||||
isComplete,
|
||||
@@ -148,7 +146,10 @@ export const AIMessage = ({
|
||||
currentPersona,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
setPopup,
|
||||
}: {
|
||||
hasChildAI?: boolean;
|
||||
hasParentAI?: boolean;
|
||||
shared?: boolean;
|
||||
isActive?: boolean;
|
||||
continueGenerating?: () => void;
|
||||
@@ -163,9 +164,8 @@ export const AIMessage = ({
|
||||
content: string | JSX.Element;
|
||||
files?: FileDescriptor[];
|
||||
query?: string;
|
||||
personaName?: string;
|
||||
citedDocuments?: [string, DanswerDocument][] | null;
|
||||
toolCall?: ToolCallMetadata;
|
||||
toolCall?: ToolCallMetadata | null;
|
||||
isComplete?: boolean;
|
||||
hasDocs?: boolean;
|
||||
handleFeedback?: (feedbackType: FeedbackType) => void;
|
||||
@@ -176,7 +176,10 @@ export const AIMessage = ({
|
||||
retrievalDisabled?: boolean;
|
||||
overriddenModel?: string;
|
||||
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
|
||||
setPopup?: (popupSpec: PopupSpec | null) => void;
|
||||
}) => {
|
||||
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
|
||||
|
||||
const toolCallGenerating = toolCall && !toolCall.tool_result;
|
||||
const processContent = (content: string | JSX.Element) => {
|
||||
if (typeof content !== "string") {
|
||||
@@ -199,12 +202,20 @@ export const AIMessage = ({
|
||||
return content;
|
||||
}
|
||||
}
|
||||
if (
|
||||
isComplete &&
|
||||
toolCall?.tool_result &&
|
||||
toolCall.tool_name == IMAGE_GENERATION_TOOL_NAME
|
||||
) {
|
||||
return content + ` [${toolCall.tool_name}]()`;
|
||||
}
|
||||
|
||||
return content + (!isComplete && !toolCallGenerating ? " [*]() " : "");
|
||||
};
|
||||
const finalContent = processContent(content as string);
|
||||
|
||||
const [isRegenerateHovered, setIsRegenerateHovered] = useState(false);
|
||||
|
||||
const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking();
|
||||
|
||||
const settings = useContext(SettingsContext);
|
||||
@@ -274,39 +285,50 @@ export const AIMessage = ({
|
||||
<div
|
||||
id="danswer-ai-message"
|
||||
ref={trackedElementRef}
|
||||
className={"py-5 ml-4 px-5 relative flex "}
|
||||
className={`${hasParentAI ? "pb-5" : "py-5"} px-2 lg:px-5 relative flex `}
|
||||
>
|
||||
<div
|
||||
className={`mx-auto ${shared ? "w-full" : "w-[90%]"} max-w-message-max`}
|
||||
>
|
||||
<div className={`desktop:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
|
||||
<div className="flex">
|
||||
<AssistantIcon
|
||||
size="small"
|
||||
assistant={alternativeAssistant || currentPersona}
|
||||
/>
|
||||
{!hasParentAI ? (
|
||||
<AssistantIcon
|
||||
size="small"
|
||||
assistant={alternativeAssistant || currentPersona}
|
||||
/>
|
||||
) : (
|
||||
<div className="w-6" />
|
||||
)}
|
||||
|
||||
<div className="w-full">
|
||||
<div className="max-w-message-max break-words">
|
||||
<div className="w-full ml-4">
|
||||
<div className="max-w-message-max break-words">
|
||||
{!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME ? (
|
||||
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (
|
||||
<>
|
||||
{query !== undefined &&
|
||||
handleShowRetrieved !== undefined &&
|
||||
isCurrentlyShowingRetrieved !== undefined &&
|
||||
!retrievalDisabled && (
|
||||
<div className="mb-1">
|
||||
<SearchSummary
|
||||
docs={docs}
|
||||
filteredDocs={filteredDocs}
|
||||
query={query}
|
||||
finished={toolCall?.tool_result != undefined}
|
||||
hasDocs={hasDocs || false}
|
||||
messageId={messageId}
|
||||
handleShowRetrieved={handleShowRetrieved}
|
||||
finished={
|
||||
toolCall?.tool_result != undefined ||
|
||||
isComplete!
|
||||
}
|
||||
toggleDocumentSelection={
|
||||
toggleDocumentSelection
|
||||
}
|
||||
handleSearchQueryEdit={handleSearchQueryEdit}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{handleForceSearch &&
|
||||
!hasChildAI &&
|
||||
content &&
|
||||
query === undefined &&
|
||||
!hasDocs &&
|
||||
@@ -318,7 +340,7 @@ export const AIMessage = ({
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
) : null}
|
||||
)}
|
||||
|
||||
{toolCall &&
|
||||
!TOOLS_WITH_CUSTOM_HANDLING.includes(
|
||||
@@ -371,6 +393,50 @@ export const AIMessage = ({
|
||||
const { node, ...rest } = props;
|
||||
const value = rest.children;
|
||||
|
||||
if (
|
||||
value
|
||||
?.toString()
|
||||
.startsWith(IMAGE_GENERATION_TOOL_NAME)
|
||||
) {
|
||||
const imageGenerationResult =
|
||||
toolCall?.tool_result as ImageGenerationResults;
|
||||
|
||||
return (
|
||||
<Popover
|
||||
open={isPopoverOpen}
|
||||
onOpenChange={
|
||||
() => null
|
||||
// setIsPopoverOpen(isPopoverOpen => !isPopoverOpen)
|
||||
} // only allow closing from the icon
|
||||
content={
|
||||
<button
|
||||
onMouseDown={() => {
|
||||
setIsPopoverOpen(!isPopoverOpen);
|
||||
}}
|
||||
>
|
||||
<ToolCallIcon className="cursor-pointer flex-none text-blue-500 hover:text-blue-700 !h-4 !w-4 inline-block" />
|
||||
</button>
|
||||
}
|
||||
popover={
|
||||
<DualPromptDisplay
|
||||
arg="Prompt"
|
||||
setPopup={setPopup!}
|
||||
prompt1={
|
||||
imageGenerationResult[0]
|
||||
.revised_prompt
|
||||
}
|
||||
prompt2={
|
||||
imageGenerationResult[1]
|
||||
.revised_prompt
|
||||
}
|
||||
/>
|
||||
}
|
||||
side="top"
|
||||
align="center"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (value?.toString().startsWith("*")) {
|
||||
return (
|
||||
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
|
||||
@@ -378,10 +444,6 @@ export const AIMessage = ({
|
||||
} else if (
|
||||
value?.toString().startsWith("[")
|
||||
) {
|
||||
// for some reason <a> tags cause the onClick to not apply
|
||||
// and the links are unclickable
|
||||
// TODO: fix the fact that you have to double click to follow link
|
||||
// for the first link
|
||||
return (
|
||||
<Citation link={rest?.href}>
|
||||
{rest.children}
|
||||
@@ -428,82 +490,10 @@ export const AIMessage = ({
|
||||
) : isComplete ? null : (
|
||||
<></>
|
||||
)}
|
||||
{isComplete && docs && docs.length > 0 && (
|
||||
<div className="mt-2 -mx-8 w-full mb-4 flex relative">
|
||||
<div className="w-full">
|
||||
<div className="px-8 flex gap-x-2">
|
||||
{!settings?.isMobile &&
|
||||
filteredDocs.length > 0 &&
|
||||
filteredDocs.slice(0, 2).map((doc, ind) => (
|
||||
<div
|
||||
key={doc.document_id}
|
||||
className={`w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 pb-2 pt-1 border-b
|
||||
`}
|
||||
>
|
||||
<a
|
||||
href={doc.link || undefined}
|
||||
target="_blank"
|
||||
className="text-sm flex w-full pt-1 gap-x-1.5 overflow-hidden justify-between font-semibold text-text-700"
|
||||
>
|
||||
<Citation link={doc.link} index={ind + 1} />
|
||||
<p className="shrink truncate ellipsis break-all">
|
||||
{doc.semantic_identifier ||
|
||||
doc.document_id}
|
||||
</p>
|
||||
<div className="ml-auto flex-none">
|
||||
{doc.is_internet ? (
|
||||
<InternetSearchIcon url={doc.link} />
|
||||
) : (
|
||||
<SourceIcon
|
||||
sourceType={doc.source_type}
|
||||
iconSize={18}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</a>
|
||||
<div className="flex overscroll-x-scroll mt-.5">
|
||||
<DocumentMetadataBlock document={doc} />
|
||||
</div>
|
||||
<div className="line-clamp-3 text-xs break-words pt-1">
|
||||
{doc.blurb}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<div
|
||||
onClick={() => {
|
||||
if (toggleDocumentSelection) {
|
||||
toggleDocumentSelection();
|
||||
}
|
||||
}}
|
||||
key={-1}
|
||||
className="cursor-pointer w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 py-2 border-b"
|
||||
>
|
||||
<div className="text-sm flex justify-between font-semibold text-text-700">
|
||||
<p className="line-clamp-1">See context</p>
|
||||
<div className="flex gap-x-1">
|
||||
{uniqueSources.map((sourceType, ind) => {
|
||||
return (
|
||||
<div key={ind} className="flex-none">
|
||||
<SourceIcon
|
||||
sourceType={sourceType}
|
||||
iconSize={18}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
<div className="line-clamp-3 text-xs break-words pt-1">
|
||||
See more
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{handleFeedback &&
|
||||
{!hasChildAI &&
|
||||
handleFeedback &&
|
||||
(isActive ? (
|
||||
<div
|
||||
className={`
|
||||
|
||||
@@ -4,9 +4,21 @@ import {
|
||||
} from "@/components/BasicClickable";
|
||||
import { HoverPopup } from "@/components/HoverPopup";
|
||||
import { Hoverable } from "@/components/Hoverable";
|
||||
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { ChevronDownIcon, InfoIcon } from "@/components/icons/icons";
|
||||
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
|
||||
import { Citation } from "@/components/search/results/Citation";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { Tooltip } from "@/components/tooltip/Tooltip";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import {
|
||||
DanswerDocument,
|
||||
FilteredDanswerDocument,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { useContext, useEffect, useRef, useState } from "react";
|
||||
import { FiCheck, FiEdit2, FiSearch, FiX } from "react-icons/fi";
|
||||
import { DownChevron } from "react-select/dist/declarations/src/components/indicators";
|
||||
|
||||
export function ShowHideDocsButton({
|
||||
messageId,
|
||||
@@ -37,24 +49,31 @@ export function ShowHideDocsButton({
|
||||
|
||||
export function SearchSummary({
|
||||
query,
|
||||
hasDocs,
|
||||
filteredDocs,
|
||||
finished,
|
||||
messageId,
|
||||
handleShowRetrieved,
|
||||
docs,
|
||||
toggleDocumentSelection,
|
||||
handleSearchQueryEdit,
|
||||
}: {
|
||||
toggleDocumentSelection?: () => void;
|
||||
docs?: DanswerDocument[] | null;
|
||||
filteredDocs: FilteredDanswerDocument[];
|
||||
finished: boolean;
|
||||
query: string;
|
||||
hasDocs: boolean;
|
||||
messageId: number | null;
|
||||
handleShowRetrieved: (messageId: number | null) => void;
|
||||
handleSearchQueryEdit?: (query: string) => void;
|
||||
}) {
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
const [finalQuery, setFinalQuery] = useState(query);
|
||||
const [isOverflowed, setIsOverflowed] = useState(false);
|
||||
const searchingForRef = useRef<HTMLDivElement>(null);
|
||||
const editQueryRef = useRef<HTMLInputElement>(null);
|
||||
const [isDropdownOpen, setIsDropdownOpen] = useState(false);
|
||||
const searchingForRef = useRef<HTMLDivElement | null>(null);
|
||||
const editQueryRef = useRef<HTMLInputElement | null>(null);
|
||||
|
||||
const settings = useContext(SettingsContext);
|
||||
|
||||
const toggleDropdown = () => {
|
||||
setIsDropdownOpen(!isDropdownOpen);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const checkOverflow = () => {
|
||||
@@ -68,7 +87,7 @@ export function SearchSummary({
|
||||
};
|
||||
|
||||
checkOverflow();
|
||||
window.addEventListener("resize", checkOverflow); // Recheck on window resize
|
||||
window.addEventListener("resize", checkOverflow);
|
||||
|
||||
return () => window.removeEventListener("resize", checkOverflow);
|
||||
}, []);
|
||||
@@ -86,15 +105,30 @@ export function SearchSummary({
|
||||
}, [query]);
|
||||
|
||||
const searchingForDisplay = (
|
||||
<div className={`flex p-1 rounded ${isOverflowed && "cursor-default"}`}>
|
||||
<FiSearch className="flex-none mr-2 my-auto" size={14} />
|
||||
<div
|
||||
className={`${!finished && "loading-text"}
|
||||
!text-sm !line-clamp-1 !break-all px-0.5`}
|
||||
ref={searchingForRef}
|
||||
>
|
||||
{finished ? "Searched" : "Searching"} for: <i> {finalQuery}</i>
|
||||
</div>
|
||||
<div
|
||||
className={`flex my-auto items-center ${isOverflowed && "cursor-default"}`}
|
||||
>
|
||||
{finished ? (
|
||||
<>
|
||||
<div
|
||||
onClick={() => {
|
||||
toggleDropdown();
|
||||
}}
|
||||
className={`transition-colors duration-300 group-hover:text-text-toolhover cursor-pointer text-text-toolrun !line-clamp-1 !break-all pr-0.5`}
|
||||
ref={searchingForRef}
|
||||
>
|
||||
Searched {filteredDocs.length > 0 && filteredDocs.length} document
|
||||
{filteredDocs.length != 1 && "s"} for {query}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<div
|
||||
className={`loading-text !text-sm !line-clamp-1 !break-all px-0.5`}
|
||||
ref={searchingForRef}
|
||||
>
|
||||
Searching for: <i> {finalQuery}</i>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -145,43 +179,126 @@ export function SearchSummary({
|
||||
</div>
|
||||
</div>
|
||||
) : null;
|
||||
const SearchBlock = ({ doc, ind }: { doc: DanswerDocument; ind: number }) => {
|
||||
return (
|
||||
<div
|
||||
onClick={() => {
|
||||
if (toggleDocumentSelection) {
|
||||
toggleDocumentSelection();
|
||||
}
|
||||
}}
|
||||
key={doc.document_id}
|
||||
className={`flex items-start gap-3 px-4 py-3 text-token-text-secondary ${ind == 0 && "rounded-t-xl"} hover:bg-background-100 group relative text-sm`}
|
||||
>
|
||||
<div className="mt-1 scale-[.9] flex-none">
|
||||
{doc.is_internet ? (
|
||||
<InternetSearchIcon url={doc.link} />
|
||||
) : (
|
||||
<SourceIcon sourceType={doc.source_type} iconSize={18} />
|
||||
)}
|
||||
</div>
|
||||
<div className="flex flex-col">
|
||||
<a
|
||||
href={doc.link}
|
||||
target="_blank"
|
||||
className="line-clamp-1 text-text-900"
|
||||
>
|
||||
<p className="shrink truncate ellipsis break-all ">
|
||||
{doc.semantic_identifier || doc.document_id}
|
||||
</p>
|
||||
<p className="line-clamp-3 text-text-500 break-words">
|
||||
{doc.blurb}
|
||||
</p>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
{isEditing ? (
|
||||
editInput
|
||||
) : (
|
||||
<>
|
||||
<div className="text-sm">
|
||||
{isOverflowed ? (
|
||||
<HoverPopup
|
||||
mainContent={searchingForDisplay}
|
||||
popupContent={
|
||||
<div>
|
||||
<b>Full query:</b>{" "}
|
||||
<div className="mt-1 italic">{query}</div>
|
||||
</div>
|
||||
}
|
||||
direction="top"
|
||||
<>
|
||||
<div className="flex gap-x-2 group">
|
||||
{isEditing ? (
|
||||
editInput
|
||||
) : (
|
||||
<>
|
||||
<div className="my-auto text-sm">
|
||||
{isOverflowed ? (
|
||||
<HoverPopup
|
||||
mainContent={searchingForDisplay}
|
||||
popupContent={
|
||||
<div>
|
||||
<b>Full query:</b>{" "}
|
||||
<div className="mt-1 italic">{query}</div>
|
||||
</div>
|
||||
}
|
||||
direction="top"
|
||||
/>
|
||||
) : (
|
||||
searchingForDisplay
|
||||
)}
|
||||
</div>
|
||||
|
||||
<button
|
||||
className="my-auto invisible group-hover:visible transition-all duration-300 rounded"
|
||||
onClick={toggleDropdown}
|
||||
>
|
||||
<ChevronDownIcon
|
||||
className={`transform transition-transform ${isDropdownOpen ? "rotate-180" : ""}`}
|
||||
/>
|
||||
</button>
|
||||
|
||||
{handleSearchQueryEdit ? (
|
||||
<Tooltip delayDuration={1000} content={"Edit Search"}>
|
||||
<button
|
||||
className="my-auto invisible group-hover:visible transition-all duration-300 cursor-pointer rounded"
|
||||
onClick={() => {
|
||||
setIsEditing(true);
|
||||
}}
|
||||
>
|
||||
<FiEdit2 />
|
||||
</button>
|
||||
</Tooltip>
|
||||
) : (
|
||||
searchingForDisplay
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
{handleSearchQueryEdit && (
|
||||
<Tooltip delayDuration={1000} content={"Edit Search"}>
|
||||
<button
|
||||
className="my-auto hover:bg-hover p-1.5 rounded"
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{isDropdownOpen && docs && docs.length > 0 && (
|
||||
<div
|
||||
className={`mt-2 -mx-8 w-full mb-4 flex relative transition-all duration-300 ${isDropdownOpen ? "opacity-100 max-h-[1000px]" : "opacity-0 max-h-0"}`}
|
||||
>
|
||||
<div className="w-full">
|
||||
<div className="mx-8 flex rounded max-h-[500px] overflow-y-scroll rounded-lg border-1.5 border divide-y divider-y-1.5 divider-y-border border-border flex-col gap-x-4">
|
||||
{!settings?.isMobile &&
|
||||
filteredDocs.length > 0 &&
|
||||
filteredDocs.map((doc, ind) => (
|
||||
<SearchBlock key={ind} doc={doc} ind={ind} />
|
||||
))}
|
||||
|
||||
<div
|
||||
onClick={() => {
|
||||
setIsEditing(true);
|
||||
if (toggleDocumentSelection) {
|
||||
toggleDocumentSelection();
|
||||
}
|
||||
}}
|
||||
key={-1}
|
||||
className="cursor-pointer w-full flex transition-all duration-500 hover:bg-background-100 py-3 border-b"
|
||||
>
|
||||
<FiEdit2 />
|
||||
</button>
|
||||
</Tooltip>
|
||||
)}
|
||||
</>
|
||||
<div key={0} className="px-3 invisible scale-[.9] flex-none">
|
||||
<SourceIcon sourceType={"file"} iconSize={18} />
|
||||
</div>
|
||||
<div className="text-sm flex justify-between text-text-900">
|
||||
<p className="line-clamp-1">See context</p>
|
||||
<div className="flex gap-x-1"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -101,7 +101,6 @@ export function SharedChatDisplay({
|
||||
messageId={message.messageId}
|
||||
content={message.message}
|
||||
files={message.files || []}
|
||||
personaName={chatSession.persona_name}
|
||||
citedDocuments={getCitedDocumentsFromMessage(message)}
|
||||
isComplete
|
||||
/>
|
||||
|
||||
83
web/src/app/chat/tools/ImagePromptCitaiton.tsx
Normal file
83
web/src/app/chat/tools/ImagePromptCitaiton.tsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { CopyIcon } from "@/components/icons/icons";
|
||||
import { Divider } from "@tremor/react";
|
||||
import React, { forwardRef, useState } from "react";
|
||||
import { FiCheck } from "react-icons/fi";
|
||||
|
||||
interface PromptDisplayProps {
|
||||
prompt1: string;
|
||||
prompt2?: string;
|
||||
arg: string;
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
}
|
||||
|
||||
const DualPromptDisplay = forwardRef<HTMLDivElement, PromptDisplayProps>(
|
||||
({ prompt1, prompt2, setPopup, arg }, ref) => {
|
||||
const [copied, setCopied] = useState<number | null>(null);
|
||||
|
||||
const copyToClipboard = (text: string, index: number) => {
|
||||
navigator.clipboard
|
||||
.writeText(text)
|
||||
.then(() => {
|
||||
setPopup({ message: "Copied to clipboard", type: "success" });
|
||||
setCopied(index);
|
||||
setTimeout(() => setCopied(null), 2000); // Reset copy status after 2 seconds
|
||||
})
|
||||
.catch((err) => {
|
||||
setPopup({ message: "Failed to copy", type: "error" });
|
||||
});
|
||||
};
|
||||
|
||||
const PromptSection = ({
|
||||
copied,
|
||||
prompt,
|
||||
index,
|
||||
}: {
|
||||
copied: number | null;
|
||||
prompt: string;
|
||||
index: number;
|
||||
}) => (
|
||||
<div className="w-full p-2 rounded-lg">
|
||||
<h2 className="text-lg font-semibold mb-2">
|
||||
{arg} {index + 1}
|
||||
</h2>
|
||||
|
||||
<p className="line-clamp-6 text-sm text-gray-800">{prompt}</p>
|
||||
|
||||
<button
|
||||
onMouseDown={() => copyToClipboard(prompt, index)}
|
||||
className="flex mt-2 text-sm cursor-pointer items-center justify-center py-2 px-3 border border-background-200 bg-inverted text-text-900 rounded-full hover:bg-background-100 transition duration-200"
|
||||
>
|
||||
{copied == index ? (
|
||||
<>
|
||||
<FiCheck className="mr-2" size={16} />
|
||||
Copied!
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<CopyIcon className="mr-2" size={16} />
|
||||
Copy
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="w-[400px] bg-inverted mx-auto p-6 rounded-lg shadow-lg">
|
||||
<div className="flex flex-col gap-x-4">
|
||||
<PromptSection copied={copied} prompt={prompt1} index={0} />
|
||||
{prompt2 && (
|
||||
<>
|
||||
<Divider />
|
||||
<PromptSection copied={copied} prompt={prompt2} index={1} />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
DualPromptDisplay.displayName = "DualPromptDisplay";
|
||||
export default DualPromptDisplay;
|
||||
@@ -145,6 +145,19 @@ export function ClientLayout({
|
||||
),
|
||||
link: "/admin/tools",
|
||||
},
|
||||
...(enableEnterprise
|
||||
? [
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<ClipboardIcon size={18} />
|
||||
<div className="ml-1">Standard Answers</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/standard-answer",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
|
||||
@@ -2811,3 +2811,40 @@ export const WindowsIcon = ({
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const ToolCallIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 19 15"
|
||||
fill="none"
|
||||
>
|
||||
<path
|
||||
d="M4.42 0.75H2.8625H2.75C1.64543 0.75 0.75 1.64543 0.75 2.75V11.65C0.75 12.7546 1.64543 13.65 2.75 13.65H2.8625C2.8625 13.65 2.8625 13.65 2.8625 13.65C2.8625 13.65 4.00751 13.65 4.42 13.65M13.98 13.65H15.5375H15.65C16.7546 13.65 17.65 12.7546 17.65 11.65V2.75C17.65 1.64543 16.7546 0.75 15.65 0.75H15.5375H13.98"
|
||||
stroke="currentColor"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M5.55283 4.21963C5.25993 3.92674 4.78506 3.92674 4.49217 4.21963C4.19927 4.51252 4.19927 4.9874 4.49217 5.28029L6.36184 7.14996L4.49217 9.01963C4.19927 9.31252 4.19927 9.7874 4.49217 10.0803C4.78506 10.3732 5.25993 10.3732 5.55283 10.0803L7.95283 7.68029C8.24572 7.3874 8.24572 6.91252 7.95283 6.61963L5.55283 4.21963Z"
|
||||
fill="currentColor"
|
||||
stroke="currentColor"
|
||||
strokeWidth="0.2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M9.77753 8.75003C9.3357 8.75003 8.97753 9.10821 8.97753 9.55003C8.97753 9.99186 9.3357 10.35 9.77753 10.35H13.2775C13.7194 10.35 14.0775 9.99186 14.0775 9.55003C14.0775 9.10821 13.7194 8.75003 13.2775 8.75003H9.77753Z"
|
||||
fill="currentColor"
|
||||
stroke="currentColor"
|
||||
strokeWidth="0.1"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import {
|
||||
DanswerDocument,
|
||||
DocumentRelevance,
|
||||
Relevance,
|
||||
SearchDanswerDocument,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { DocumentFeedbackBlock } from "./DocumentFeedbackBlock";
|
||||
@@ -12,11 +11,10 @@ import { PopupSpec } from "../admin/connectors/Popup";
|
||||
import { DocumentUpdatedAtBadge } from "./DocumentUpdatedAtBadge";
|
||||
import { SourceIcon } from "../SourceIcon";
|
||||
import { MetadataBadge } from "../MetadataBadge";
|
||||
import { BookIcon, CheckmarkIcon, LightBulbIcon, XIcon } from "../icons/icons";
|
||||
import { BookIcon, LightBulbIcon } from "../icons/icons";
|
||||
|
||||
import { FaStar } from "react-icons/fa";
|
||||
import { FiTag } from "react-icons/fi";
|
||||
import { DISABLE_LLM_DOC_RELEVANCE } from "@/lib/constants";
|
||||
import { SettingsContext } from "../settings/SettingsProvider";
|
||||
import { CustomTooltip, TooltipGroup } from "../tooltip/CustomTooltip";
|
||||
import { WarningCircle } from "@phosphor-icons/react";
|
||||
|
||||
@@ -22,6 +22,8 @@ export interface AnswerPiecePacket {
|
||||
export enum StreamStopReason {
|
||||
CONTEXT_LENGTH = "CONTEXT_LENGTH",
|
||||
CANCELLED = "CANCELLED",
|
||||
FINISHED = "FINISHED",
|
||||
NEW_RESPONSE = "NEW_RESPONSE",
|
||||
}
|
||||
|
||||
export interface StreamStopInfo {
|
||||
|
||||
Reference in New Issue
Block a user