Forked from github/onyx
Compare commits: main...concurrent (50 commits)
Commits (author and date columns did not survive extraction; SHAs in order):
4eb53ce56f, 2fc84ed63e, 722d5e6e54, 14c30d2e4d, 6abad2fdd3,
4691e736f6, 5a826a527f, f92d31df70, 1eb786897a, 72471f9e1d,
49c335d06a, fda06b7739, 00d44e31b3, 2a42c1dd18, 05cd25043e,
abebff50bb, 0a7e672832, 221ab9134c, f7134202b6, bea11dc3aa,
374b798071, 6a2e3edfcd, 2ef1731e32, 7d4d7a5f5d, ea2f9cf625,
97dc9c5e31, 249bcd46d9, f29b727bc7, 31fb6c0753, a45e72c298,
157548817c, d9396f77d1, 7bae6bbf8f, 1d535769ed, 8584a81fe2,
5f4ac19928, d898e4f738, 19412f0aa0, c338de30fd, edfde621b9,
9306abf911, 70d885b621, 53bea4f859, a79d734d96, 25cd7de147,
ab2916c807, 96112f1f95, 54502b32d3, 9431e6c06c, f18571d580
backend/alembic/versions/eb690a089310_migrate_tool_calls.py (new file, 59 lines)

@@ -0,0 +1,59 @@
"""migrate tool calls

Revision ID: eb690a089310
Revises: ee3f4b47fad5
Create Date: 2024-08-04 17:07:47.533051

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "eb690a089310"
down_revision = "ee3f4b47fad5"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create the new column
    op.add_column(
        "chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
    )
    op.create_foreign_key(
        "fk_chat_message_tool_call",
        "chat_message",
        "tool_call",
        ["tool_call_id"],
        ["id"],
    )

    # Migrate existing data
    op.execute(
        "UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
    )

    # Drop the old relationship
    op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
    op.drop_column("tool_call", "message_id")


def downgrade() -> None:
    # Add back the old column
    op.add_column(
        "tool_call",
        sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
    )
    op.create_foreign_key(
        "tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
    )

    # Migrate data back
    op.execute(
        "UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
    )

    # Drop the new column
    op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
    op.drop_column("chat_message", "tool_call_id")
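The upgrade path collapses a one-to-many relationship (many tool_call rows per chat_message) into a single nullable tool_call_id on chat_message, and the data-migration UPDATE uses LIMIT 1, so a message that had more than one tool call silently keeps only one. A minimal pre-flight sketch along these lines could surface affected rows before upgrading; the connection string is a placeholder, not part of the diff:

from sqlalchemy import create_engine, text

engine = create_engine("postgresql://localhost/danswer")  # hypothetical DSN
with engine.connect() as conn:
    # Messages with more than one tool call lose all but one row in upgrade()
    rows = conn.execute(
        text(
            "SELECT message_id, COUNT(*) AS n FROM tool_call "
            "GROUP BY message_id HAVING COUNT(*) > 1"
        )
    ).all()
    for message_id, n in rows:
        print(f"message {message_id} has {n} tool calls; only one survives")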
@@ -0,0 +1,28 @@
"""Added alternate model to chat message

Revision ID: ee3f4b47fad5
Revises: 4a951134c801
Create Date: 2024-08-12 00:11:50.915845

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "ee3f4b47fad5"
down_revision = "4a951134c801"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_message",
        sa.Column("alternate_model", sa.String(length=255), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("chat_message", "alternate_model")
@@ -36,9 +36,11 @@ def create_chat_chain(
     chat_session_id: int,
     db_session: Session,
     prefetch_tool_calls: bool = True,
+    parent_id: int | None = None,
 ) -> tuple[ChatMessage, list[ChatMessage]]:
     """Build the linear chain of messages without including the root message"""
     mainline_messages: list[ChatMessage] = []

     all_chat_messages = get_chat_messages_by_session(
         chat_session_id=chat_session_id,
         user_id=None,
@@ -60,7 +62,7 @@ def create_chat_chain(
     current_message: ChatMessage | None = root_message
     while current_message is not None:
         child_msg = current_message.latest_child_message
-        if not child_msg:
+        if not child_msg or (parent_id and current_message.id == parent_id):
             break
         current_message = id_to_msg.get(child_msg)
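The new parent_id parameter lets callers stop the walk at a specific message instead of always following latest_child_message to the newest leaf; the regeneration flow later in this diff relies on it. A usage sketch, with the session and IDs assumed to come from the request context:

# Rebuild the chain only up to the message being regenerated.
final_msg, history_msgs = create_chat_chain(
    chat_session_id=chat_session_id,
    db_session=db_session,
    parent_id=parent_id,  # stop here instead of at the latest child
)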
@@ -64,6 +64,10 @@ class DocumentRelevance(BaseModel):
     relevance_summaries: dict[str, RelevanceAnalysis]


+class Delimiter(BaseModel):
+    delimiter: bool
+
+
 class DanswerAnswerPiece(BaseModel):
     # A small piece of a complete answer. Used for streaming back answers.
     answer_piece: str | None  # if None, specifies the end of an Answer
@@ -76,6 +80,11 @@ class CitationInfo(BaseModel):
     document_id: str


+class MessageResponseIDInfo(BaseModel):
+    user_message_id: int | None
+    reserved_assistant_message_id: int
+
+
 class StreamingError(BaseModel):
     error: str
     stack_trace: str | None = None
@@ -137,6 +146,7 @@ AnswerQuestionPossibleReturn = (
     | ImageGenerationDisplay
     | CustomToolResponse
     | StreamingError
+    | Delimiter
 )
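Delimiter marks the boundary between one complete assistant answer and the next within a single stream (used by the multi-step tool-calling loop below), while MessageResponseIDInfo tells the client up front which database IDs the stream will fill in. A minimal consumer sketch, with the stream iterable assumed:

answers: list[str] = []
current = ""
for packet in stream:  # assumed: an iterator of ChatPacket objects
    if isinstance(packet, MessageResponseIDInfo):
        assistant_id = packet.reserved_assistant_message_id  # row to poll later
    elif isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
        current += packet.answer_piece
    elif isinstance(packet, Delimiter):
        answers.append(current)  # one full answer done; more may follow
        current = ""
if current:
    answers.append(current)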
@@ -9,8 +9,10 @@ from danswer.chat.chat_utils import create_chat_chain
 from danswer.chat.models import CitationInfo
 from danswer.chat.models import CustomToolResponse
 from danswer.chat.models import DanswerAnswerPiece
+from danswer.chat.models import Delimiter
 from danswer.chat.models import ImageGenerationDisplay
 from danswer.chat.models import LLMRelevanceFilterResponse
+from danswer.chat.models import MessageResponseIDInfo
 from danswer.chat.models import QADocsResponse
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import BING_API_KEY
@@ -27,6 +29,7 @@ from danswer.db.chat import get_chat_session_by_id
 from danswer.db.chat import get_db_search_doc_by_id
 from danswer.db.chat import get_doc_query_identifiers_from_model
 from danswer.db.chat import get_or_create_root_message
+from danswer.db.chat import reserve_message_id
 from danswer.db.chat import translate_db_message_to_chat_message_detail
 from danswer.db.chat import translate_db_search_doc_to_server_search_doc
 from danswer.db.embedding_model import get_current_db_embedding_model
@@ -88,7 +91,7 @@ from danswer.tools.search.search_tool import SearchTool
 from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
 from danswer.tools.tool import Tool
 from danswer.tools.tool import ToolResponse
-from danswer.tools.tool_runner import ToolCallFinalResult
+from danswer.tools.tool_runner import ToolCallMetadata
 from danswer.tools.utils import compute_all_tool_tokens
 from danswer.tools.utils import explicit_tool_calling_supported
 from danswer.utils.logger import setup_logger
@@ -241,6 +244,8 @@ ChatPacket = (
     | CitationInfo
     | ImageGenerationDisplay
     | CustomToolResponse
+    | MessageResponseIDInfo
+    | Delimiter
 )
 ChatPacketStream = Iterator[ChatPacket]
@@ -256,9 +261,9 @@ def stream_chat_message_objects(
     max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
     # if specified, uses the last user message and does not create a new user message based
     # on the `new_msg_req.message`. Currently, requires a state where the last message is a
     # user message (e.g. this can only be used for the chat-seeding flow).
     use_existing_user_message: bool = False,
     litellm_additional_headers: dict[str, str] | None = None,
     is_connected: Callable[[], bool] | None = None,
 ) -> ChatPacketStream:
     """Streams in order:
     1. [conditional] Retrieved documents if a search needs to be run
@@ -342,7 +347,15 @@ def stream_chat_message_objects(
         parent_message = root_message

         user_message = None
-        if not use_existing_user_message:
+        if new_msg_req.regenerate:
+            final_msg, history_msgs = create_chat_chain(
+                parent_id=parent_id,
+                chat_session_id=chat_session_id,
+                db_session=db_session,
+            )
+
+        elif not use_existing_user_message:
             # Create new message at the right place in the tree and update the parent's child pointer
             # Don't commit yet until we verify the chat message chain
             user_message = create_new_chat_message(
@@ -451,13 +464,29 @@ def stream_chat_message_objects(
             use_sections=new_msg_req.chunks_above > 0
             or new_msg_req.chunks_below > 0,
         )
+        reserved_message_id = reserve_message_id(
+            db_session=db_session,
+            chat_session_id=chat_session_id,
+            parent_message=user_message.id
+            if user_message is not None
+            else parent_message.id,
+            message_type=MessageType.ASSISTANT,
+        )
+        yield MessageResponseIDInfo(
+            user_message_id=user_message.id if user_message else None,
+            reserved_assistant_message_id=reserved_message_id,
+        )
+
+        alternate_model = (
+            new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
+        )
         # Cannot determine these without the LLM step or breaking out early
         partial_response = partial(
             create_new_chat_message,
             chat_session_id=chat_session_id,
             parent_message=final_msg,
             prompt_id=prompt_id,
+            alternate_model=alternate_model,
             # message=,
             # rephrased_query=,
             # token_count=,
@@ -581,9 +610,11 @@ def stream_chat_message_objects(
         document_pruning_config.using_tool_message = explicit_tool_calling_supported(
             llm_provider, llm_model_name
         )
+        tool_has_been_called = False  # TODO remove

         # LLM prompt building, response capturing, etc.
         answer = Answer(
             is_connected=is_connected,
             question=final_msg.message,
             latest_query_files=latest_query_files,
             answer_style_config=AnswerStyleConfig(
@@ -617,8 +648,11 @@ def stream_chat_message_objects(
         ai_message_files = None  # any files to associate with the AI message e.g. dall-e generated images
         dropped_indices = None
+        tool_result = None
+
         for packet in answer.processed_streamed_output:
             if isinstance(packet, ToolResponse):
+                tool_has_been_called = True
                 if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
                     (
                         qa_docs_response,
@@ -689,9 +723,76 @@ def stream_chat_message_objects(
                     )

             else:
-                if isinstance(packet, ToolCallFinalResult):
-                    tool_result = packet
-                yield cast(ChatPacket, packet)
+                if isinstance(packet, Delimiter):
+                    db_citations = None
+
+                    if reference_db_search_docs:
+                        db_citations = translate_citations(
+                            citations_list=answer.citations,
+                            db_docs=reference_db_search_docs,
+                        )
+
+                    # Saving Gen AI answer and responding with message info
+                    tool_name_to_tool_id: dict[str, int] = {}
+                    for tool_id, tool_list in tool_dict.items():
+                        for tool in tool_list:
+                            tool_name_to_tool_id[tool.name] = tool_id
+
+                    if tool_result is None:
+                        tool_call = None
+                    else:
+                        tool_call = ToolCall(
+                            tool_id=tool_name_to_tool_id[tool_result.tool_name],
+                            tool_name=tool_result.tool_name,
+                            tool_arguments=tool_result.tool_args,
+                            tool_result=tool_result.tool_result,
+                        )
+
+                    gen_ai_response_message = partial_response(
+                        message=answer.llm_answer,
+                        rephrased_query=(
+                            qa_docs_response.rephrased_query
+                            if qa_docs_response
+                            else None
+                        ),
+                        reference_docs=reference_db_search_docs,
+                        files=ai_message_files,
+                        token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
+                        citations=db_citations,
+                        error=None,
+                        tool_call=tool_call,
+                    )
+
+                    db_session.commit()  # actually save user / assistant message
+
+                    msg_detail_response = translate_db_message_to_chat_message_detail(
+                        gen_ai_response_message
+                    )
+
+                    yield msg_detail_response
+                    yield Delimiter(delimiter=True)
+                    partial_response = partial(
+                        create_new_chat_message,
+                        chat_session_id=chat_session_id,
+                        parent_message=gen_ai_response_message,
+                        prompt_id=prompt_id,
+                        # message=,
+                        # rephrased_query=,
+                        # token_count=,
+                        message_type=MessageType.ASSISTANT,
+                        alternate_assistant_id=new_msg_req.alternate_assistant_id,
+                        # error=,
+                        # reference_docs=,
+                        db_session=db_session,
+                        commit=False,
+                    )
+
+                else:
+                    if isinstance(packet, ToolCallMetadata):
+                        tool_result = packet
+                    yield cast(ChatPacket, packet)
         logger.debug("Reached end of stream")

     except Exception as e:
         error_msg = str(e)
         logger.exception(f"Failed to process chat message: {error_msg}")
@@ -703,54 +804,56 @@ def stream_chat_message_objects(
         db_session.rollback()
         return

     # Post-LLM answer processing
-    try:
-        db_citations = None
-        if reference_db_search_docs:
-            db_citations = translate_citations(
-                citations_list=answer.citations,
-                db_docs=reference_db_search_docs,
-            )
+    if not tool_has_been_called:
+        try:
+            db_citations = None
+            if reference_db_search_docs:
+                db_citations = translate_citations(
+                    citations_list=answer.citations,
+                    db_docs=reference_db_search_docs,
+                )

-        # Saving Gen AI answer and responding with message info
-        tool_name_to_tool_id: dict[str, int] = {}
-        for tool_id, tool_list in tool_dict.items():
-            for tool in tool_list:
-                tool_name_to_tool_id[tool.name] = tool_id
+            # Saving Gen AI answer and responding with message info
+            tool_name_to_tool_id = {}
+            for tool_id, tool_list in tool_dict.items():
+                for tool in tool_list:
+                    tool_name_to_tool_id[tool.name] = tool_id

-        gen_ai_response_message = partial_response(
-            message=answer.llm_answer,
-            rephrased_query=(
-                qa_docs_response.rephrased_query if qa_docs_response else None
-            ),
-            reference_docs=reference_db_search_docs,
-            files=ai_message_files,
-            token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
-            citations=db_citations,
-            error=None,
-            tool_calls=[
-                ToolCall(
+            gen_ai_response_message = partial_response(
+                reserved_message_id=reserved_message_id,
+                message=answer.llm_answer,
+                rephrased_query=(
+                    qa_docs_response.rephrased_query if qa_docs_response else None
+                ),
+                reference_docs=reference_db_search_docs,
+                files=ai_message_files,
+                token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
+                citations=db_citations,
+                error=None,
+                tool_call=ToolCall(
                     tool_id=tool_name_to_tool_id[tool_result.tool_name],
                     tool_name=tool_result.tool_name,
                     tool_arguments=tool_result.tool_args,
                     tool_result=tool_result.tool_result,
                 )
-            ]
-            if tool_result
-            else [],
-        )
-        db_session.commit()  # actually save user / assistant message
+                if tool_result
+                else None,
+            )

             msg_detail_response = translate_db_message_to_chat_message_detail(
                 gen_ai_response_message
             )
+            logger.debug("Committing messages")
+            db_session.commit()  # actually save user / assistant message

             yield msg_detail_response
-    except Exception as e:
-        logger.exception(e)
-
-        msg_detail_response = translate_db_message_to_chat_message_detail(
-            gen_ai_response_message
-        )
-
-        # Frontend will erase whatever answer and show this instead
-        yield StreamingError(error="Failed to parse LLM output")
-        yield msg_detail_response
+        except Exception as e:
+            error_msg = str(e)
+            logger.exception(error_msg)
+
+            # Frontend will erase whatever answer and show this instead
+            yield StreamingError(error="Failed to parse LLM output")
 @log_generator_function_time()
@@ -759,6 +862,7 @@ def stream_chat_message(
     user: User | None,
     use_existing_user_message: bool = False,
     litellm_additional_headers: dict[str, str] | None = None,
+    is_connected: Callable[[], bool] | None = None,
 ) -> Iterator[str]:
     with get_session_context_manager() as db_session:
         objects = stream_chat_message_objects(
@@ -767,6 +871,7 @@ def stream_chat_message(
             db_session=db_session,
             use_existing_user_message=use_existing_user_message,
             litellm_additional_headers=litellm_additional_headers,
+            is_connected=is_connected,
         )
         for obj in objects:
             yield get_json_line(obj.dict())
@@ -42,7 +42,8 @@ prompts:
     task: >
       Generate an image based on the user's description.

-      Provide a detailed description of the generated image, including key elements, colors, and composition.
+      Provide a detailed description of the generated image, including key elements, colors, and composition.
+

       If the request is not possible or appropriate, explain why and suggest alternatives.
   datetime_aware: true
@@ -13,6 +13,9 @@ DEFAULT_BOOST = 0
 SESSION_KEY = "session"

+# For tool calling
+MAXIMUM_TOOL_CALL_SEQUENCE = 5
+
 # For chunking/processing chunks
 RETURN_SEPARATOR = "\n\r\n"
 SECTION_SEPARATOR = "\n\n"
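MAXIMUM_TOOL_CALL_SEQUENCE bounds how many chained tool calls a single request may trigger. The looping tool-call method added later in this diff still hardcodes its own limit (maximum_count = 4); a sketch of how the constant would presumably be consumed instead:

from danswer.configs.constants import MAXIMUM_TOOL_CALL_SEQUENCE

count = 1
while count <= MAXIMUM_TOOL_CALL_SEQUENCE:
    # one LLM step per iteration; stop once no further tool call is requested
    tool_call_requested = run_one_step()  # hypothetical helper, not in the diff
    if not tool_call_requested:
        break
    count += 1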
@@ -143,7 +143,7 @@ def handle_standard_answers(
         parent_message=root_message,
         prompt_id=prompt.id if prompt else None,
         message=query_msg.message,
-        token_count=0,
+        token_count=10,
         message_type=MessageType.USER,
         db_session=db_session,
         commit=True,
@@ -36,7 +36,7 @@ from danswer.search.models import RetrievalDocs
 from danswer.search.models import SavedSearchDoc
 from danswer.search.models import SearchDoc as ServerSearchDoc
 from danswer.server.query_and_chat.models import ChatMessageDetail
-from danswer.tools.tool_runner import ToolCallFinalResult
+from danswer.tools.tool_runner import ToolCallMetadata
 from danswer.utils.logger import setup_logger
@@ -147,8 +147,14 @@ def delete_search_doc_message_relationship(


 def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
-    stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
-    db_session.execute(stmt)
+    chat_message = (
+        db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
+    )
+    if chat_message and chat_message.tool_call_id:
+        stmt = delete(ToolCall).where(ToolCall.id == chat_message.tool_call_id)
+        db_session.execute(stmt)
+        chat_message.tool_call_id = None
+
     db_session.commit()
@@ -350,7 +356,7 @@ def get_chat_messages_by_session(
     )

     if prefetch_tool_calls:
-        stmt = stmt.options(joinedload(ChatMessage.tool_calls))
+        stmt = stmt.options(joinedload(ChatMessage.tool_call))
         result = db_session.scalars(stmt).unique().all()
     else:
         result = db_session.scalars(stmt).all()
@@ -393,6 +399,34 @@ def get_or_create_root_message(
     return new_root_message


+def reserve_message_id(
+    db_session: Session,
+    chat_session_id: int,
+    parent_message: int,
+    message_type: MessageType,
+) -> int:
+    # Create an empty chat message
+    empty_message = ChatMessage(
+        chat_session_id=chat_session_id,
+        parent_message=parent_message,
+        latest_child_message=None,
+        message="",
+        token_count=0,
+        message_type=message_type,
+    )
+
+    # Add the empty message to the session
+    db_session.add(empty_message)
+
+    # Flush the session to get an ID for the new chat message
+    db_session.flush()
+
+    # Get the ID of the newly created message
+    new_id = empty_message.id
+
+    return new_id
+
+
 def create_new_chat_message(
     chat_session_id: int,
     parent_message: ChatMessage,
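reserve_message_id flushes (but does not commit) an empty row so a primary key exists before any tokens stream; create_new_chat_message below can then fill in that same row via reserved_message_id. A condensed sketch of the two-step pattern, with arguments not shown here (such as prompt_id) omitted as at the real call site:

reserved_id = reserve_message_id(
    db_session=db_session,
    chat_session_id=chat_session_id,
    parent_message=user_message.id,
    message_type=MessageType.ASSISTANT,
)
# ... stream the answer to the client, keyed by reserved_id ...
assistant_msg = create_new_chat_message(
    chat_session_id=chat_session_id,
    parent_message=user_message,
    message=final_answer_text,  # assumed: the accumulated streamed answer
    token_count=token_count,    # assumed: precomputed
    message_type=MessageType.ASSISTANT,
    db_session=db_session,
    reserved_message_id=reserved_id,  # update the reserved row in place
)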
@@ -408,38 +442,62 @@ def create_new_chat_message(
     alternate_assistant_id: int | None = None,
     # Maps the citation number [n] to the DB SearchDoc
     citations: dict[int, int] | None = None,
-    tool_calls: list[ToolCall] | None = None,
+    tool_call: ToolCall | None = None,
     commit: bool = True,
+    reserved_message_id: int | None = None,
+    alternate_model: str | None = None,
 ) -> ChatMessage:
-    new_chat_message = ChatMessage(
-        chat_session_id=chat_session_id,
-        parent_message=parent_message.id,
-        latest_child_message=None,
-        message=message,
-        rephrased_query=rephrased_query,
-        prompt_id=prompt_id,
-        token_count=token_count,
-        message_type=message_type,
-        citations=citations,
-        files=files,
-        tool_calls=tool_calls if tool_calls else [],
-        error=error,
-        alternate_assistant_id=alternate_assistant_id,
-    )
+    if reserved_message_id is not None:
+        # Edit existing message
+        existing_message = db_session.query(ChatMessage).get(reserved_message_id)
+        if existing_message is None:
+            raise ValueError(f"No message found with id {reserved_message_id}")
+
+        existing_message.chat_session_id = chat_session_id
+        existing_message.parent_message = parent_message.id
+        existing_message.message = message
+        existing_message.rephrased_query = rephrased_query
+        existing_message.prompt_id = prompt_id
+        existing_message.token_count = token_count
+        existing_message.message_type = message_type
+        existing_message.citations = citations
+        existing_message.files = files
+        existing_message.tool_call = tool_call
+        existing_message.error = error
+        existing_message.alternate_assistant_id = alternate_assistant_id
+        existing_message.alternate_model = alternate_model
+
+        new_chat_message = existing_message
+    else:
+        # Create new message
+        new_chat_message = ChatMessage(
+            chat_session_id=chat_session_id,
+            parent_message=parent_message.id,
+            latest_child_message=None,
+            message=message,
+            rephrased_query=rephrased_query,
+            prompt_id=prompt_id,
+            token_count=token_count,
+            message_type=message_type,
+            citations=citations,
+            files=files,
+            tool_call=tool_call,
+            error=error,
+            alternate_assistant_id=alternate_assistant_id,
+            alternate_model=alternate_model,
+        )
+        db_session.add(new_chat_message)

     # SQL Alchemy will propagate this to update the reference_docs' foreign keys
     if reference_docs:
         new_chat_message.search_docs = reference_docs

-    db_session.add(new_chat_message)
-
     # Flush the session to get an ID for the new chat message
     db_session.flush()

     parent_message.latest_child_message = new_chat_message.id
     if commit:
         db_session.commit()

     return new_chat_message
@@ -656,15 +714,15 @@ def translate_db_message_to_chat_message_detail(
         time_sent=chat_message.time_sent,
         citations=chat_message.citations,
         files=chat_message.files or [],
-        tool_calls=[
-            ToolCallFinalResult(
-                tool_name=tool_call.tool_name,
-                tool_args=tool_call.tool_arguments,
-                tool_result=tool_call.tool_result,
-            )
-            for tool_call in chat_message.tool_calls
-        ],
+        tool_call=ToolCallMetadata(
+            tool_name=chat_message.tool_call.tool_name,
+            tool_args=chat_message.tool_call.tool_arguments,
+            tool_result=chat_message.tool_call.tool_result,
+        )
+        if chat_message.tool_call
+        else None,
         alternate_assistant_id=chat_message.alternate_assistant_id,
+        alternate_model=chat_message.alternate_model,
     )

     return chat_msg_detail
@@ -764,10 +764,11 @@ class ToolCall(Base):
     tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
     tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())

-    message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
-
     message: Mapped["ChatMessage"] = relationship(
-        "ChatMessage", back_populates="tool_calls"
+        "ChatMessage",
+        back_populates="tool_call",
+        uselist=False,
+        foreign_keys="ChatMessage.tool_call_id",
     )
@@ -848,9 +849,13 @@ class ChatMessage(Base):
         Integer, ForeignKey("persona.id"), nullable=True
     )

+    alternate_model: Mapped[str | None] = mapped_column(String, nullable=True)
     parent_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
     latest_child_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
     message: Mapped[str] = mapped_column(Text)
+    tool_call_id: Mapped[int | None] = mapped_column(
+        ForeignKey("tool_call.id"), nullable=True
+    )
     rephrased_query: Mapped[str] = mapped_column(Text, nullable=True)
     # If None, then there is no answer generation, it's the special case of only
     # showing the user the retrieved docs
@@ -893,9 +898,8 @@ class ChatMessage(Base):
     )
     # NOTE: Should always be attached to the `assistant` message.
     # represents the tool calls used to generate this message
-    tool_calls: Mapped[list["ToolCall"]] = relationship(
-        "ToolCall",
-        back_populates="message",
+    tool_call: Mapped["ToolCall"] = relationship(
+        "ToolCall", back_populates="message", foreign_keys=[tool_call_id]
     )
     standard_answers: Mapped[list["StandardAnswer"]] = relationship(
         "StandardAnswer",
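Together with the migration at the top of this diff, these model changes flip the schema from one-message-to-many-tool-calls to at most one tool call per message, with the foreign key now living on chat_message.tool_call_id. A navigation sketch under the new mapping, with db_session and message_id assumed:

msg = db_session.query(ChatMessage).get(message_id)
if msg.tool_call is not None:
    print(msg.tool_call.tool_name, msg.tool_call.tool_arguments)
    # and the reverse direction still works: msg.tool_call.message is msg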
@@ -1,4 +1,6 @@
 from collections.abc import Callable
+from collections.abc import Iterator
+from typing import Any
 from typing import cast
 from uuid import uuid4
@@ -6,12 +8,19 @@ from langchain.schema.messages import BaseMessage
 from langchain_core.messages import AIMessageChunk
 from langchain_core.messages import HumanMessage

 from danswer.chat.chat_utils import llm_doc_from_inference_section
 from danswer.chat.models import AnswerQuestionPossibleReturn
+from danswer.chat.models import AnswerQuestionStreamReturn
 from danswer.chat.models import CitationInfo
+from danswer.chat.models import CustomToolResponse
 from danswer.chat.models import DanswerAnswerPiece
+from danswer.chat.models import DanswerContexts
+from danswer.chat.models import DanswerQuotes
+from danswer.chat.models import Delimiter
+from danswer.chat.models import ImageGenerationDisplay
 from danswer.chat.models import LlmDoc
+from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
+from danswer.configs.constants import MessageType
 from danswer.file_store.utils import InMemoryChatFile
 from danswer.llm.answering.models import AnswerStyleConfig
 from danswer.llm.answering.models import PreviousMessage
@@ -51,20 +60,21 @@ from danswer.tools.message import ToolCallSummary
 from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
 from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
 from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
 from danswer.tools.search.search_tool import SearchResponseSummary
 from danswer.tools.search.search_tool import SearchTool
 from danswer.tools.tool import Tool
 from danswer.tools.tool import ToolResponse
 from danswer.tools.tool_runner import (
     check_which_tools_should_run_for_non_tool_calling_llm,
 )
-from danswer.tools.tool_runner import ToolCallFinalResult
 from danswer.tools.tool_runner import ToolCallKickoff
+from danswer.tools.tool_runner import ToolCallMetadata
 from danswer.tools.tool_runner import ToolRunner
 from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
 from danswer.tools.utils import explicit_tool_calling_supported
 from danswer.utils.logger import setup_logger

+# DanswerQuotes | CitationInfo | DanswerContexts | ImageGenerationDisplay | CustomToolResponse | StreamingError

 logger = setup_logger()
@@ -115,6 +125,8 @@ class Answer:
         # Returns the full document sections text from the search tool
         return_contexts: bool = False,
         skip_gen_ai_answer_generation: bool = False,
+        is_connected: Callable[[], bool] | None = None,
+        tool_uses: dict | None = None,
     ) -> None:
         if single_message_history and message_history:
             raise ValueError(
@@ -122,6 +134,7 @@ class Answer:
         )

         self.question = question
+        self.is_connected: Callable[[], bool] | None = is_connected

         self.latest_query_files = latest_query_files or []
         self.file_id_to_file = {file.file_id: file for file in (files or [])}
@@ -132,12 +145,17 @@ class Answer:
         self.skip_explicit_tool_calling = skip_explicit_tool_calling

         self.message_history = message_history or []
+        self.tool_uses = tool_uses or {}
         # used for QA flow where we only want to send a single message
         self.single_message_history = single_message_history

         self.answer_style_config = answer_style_config
         self.prompt_config = prompt_config

+        self.current_streamed_output: list = []
+        self.processing_stream: list = []
+        self.final_context_docs: list = []
+
         self.llm = llm
         self.llm_tokenizer = get_tokenizer(
             provider_type=llm.config.model_provider,
@@ -153,6 +171,7 @@ class Answer:

         self._return_contexts = return_contexts
         self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
+        self._is_cancelled = False
     def _update_prompt_builder_for_search_tool(
         self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
@@ -174,6 +193,7 @@ class Answer:
                     ),
                 )
             )
+
         elif self.answer_style_config.quotes_config:
             prompt_builder.update_user_prompt(
                 build_quotes_user_message(
@@ -184,11 +204,219 @@ class Answer:
    def _raw_output_for_explicit_tool_calling_llms_loop(
        self,
    ) -> Iterator[
        str | ToolCallKickoff | ToolResponse | ToolCallMetadata | Delimiter | Any
    ]:
        count = 1
        maximum_count = 4
        while count <= maximum_count:
            prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)

            count += 1
            print(f"COUNT IS {count}")
            tool_call_chunk: AIMessageChunk | None = None

            if self.force_use_tool.force_use and self.force_use_tool.args is not None:
                tool_call_chunk = AIMessageChunk(content="")
                tool_call_chunk.tool_calls = [
                    {
                        "name": self.force_use_tool.tool_name,
                        "args": self.force_use_tool.args,
                        "id": str(uuid4()),
                    }
                ]
            else:
                prompt_builder.update_system_prompt(
                    default_build_system_message(self.prompt_config)
                )
                prompt_builder.update_user_prompt(
                    default_build_user_message(
                        self.question, self.prompt_config, self.latest_query_files
                    )
                )
                prompt = prompt_builder.build()
                print(f"-----------------------\nThe current prompt is: {prompt}")

                final_tool_definitions = [
                    tool.tool_definition()
                    for tool in filter_tools_for_force_tool_use(
                        self.tools, self.force_use_tool
                    )
                ]

                for message in self.llm.stream(
                    prompt=prompt,
                    tools=final_tool_definitions if final_tool_definitions else None,
                    tool_choice="required" if self.force_use_tool.force_use else None,
                ):
                    if isinstance(message, AIMessageChunk) and (
                        message.tool_call_chunks or message.tool_calls
                    ):
                        if tool_call_chunk is None:
                            tool_call_chunk = message
                        else:
                            tool_call_chunk += message  # type: ignore
                    else:
                        if tool_call_chunk is None and count != 2:
                            print("Skipping the tool call + message completely")
                            return
                        elif message.content:
                            yield cast(str, message.content)

            if not tool_call_chunk:
                print(
                    "Skipping the tool call but generated message due to lack of existing tool call messages"
                )
                return

            tool_call_requests = tool_call_chunk.tool_calls

            logger.critical(
                f"-------------------TOOL CALL REQUESTS ({len(tool_call_requests)})-------------------"
            )

            for tool_call_request in tool_call_requests:
                known_tools_by_name = [
                    tool
                    for tool in self.tools
                    if tool.name == tool_call_request["name"]
                ]

                if not known_tools_by_name:
                    logger.error(
                        "Tool call requested with unknown name field. \n"
                        f"self.tools: {self.tools}"
                        f"tool_call_request: {tool_call_request}"
                    )
                    if self.tools:
                        tool = self.tools[0]
                    else:
                        continue
                else:
                    tool = known_tools_by_name[0]

                tool_args = (
                    self.force_use_tool.args
                    if self.force_use_tool.tool_name == tool.name
                    and self.force_use_tool.args
                    else tool_call_request["args"]
                )
                print("The tool call request is this:")
                print(tool_args)

                tool_runner = ToolRunner(tool, tool_args)
                yield tool_runner.kickoff()

                tool_responses = list(tool_runner.tool_responses())
                yield from tool_responses

                tool_call_summary = ToolCallSummary(
                    tool_call_request=tool_call_chunk,
                    tool_call_result=build_tool_message(
                        tool_call_request, tool_runner.tool_message_content()
                    ),
                )

                if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
                    self._update_prompt_builder_for_search_tool(prompt_builder, [])
                elif tool.name == ImageGenerationTool._NAME:
                    img_urls = [
                        img_generation_result["url"]
                        for img_generation_result in tool_runner.tool_final_result().tool_result
                    ]
                    prompt_builder.update_user_prompt(
                        build_image_generation_user_prompt(
                            query=self.question, img_urls=img_urls
                        )
                    )

                yield tool_runner.tool_final_result()

                # Update message history with tool call and response
                self.message_history.append(
                    PreviousMessage(
                        message=str(tool_call_request),
                        message_type=MessageType.ASSISTANT,
                        token_count=10,  # You may want to implement a token counting method
                        tool_call=None,
                        files=[],
                    )
                )
                self.message_history.append(
                    PreviousMessage(
                        message="\n".join(str(response) for response in tool_responses),
                        message_type=MessageType.SYSTEM,
                        token_count=10,
                        tool_call=None,
                        files=[],
                    )
                )

                # Generate response based on updated message history
                prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
                print("-------------")

                print("NEW PROMPT")
                print(prompt)
                print("\n\n\n\n\n\n-------\n\n------")

                process_answer_stream_fn = _get_answer_stream_processor(
                    context_docs=self.final_context_docs or [],
                    doc_id_to_rank_map=map_document_id_order(
                        self.final_context_docs or []
                    ),
                    answer_style_configs=self.answer_style_config,
                )

                response_stream = process_answer_stream_fn(
                    message_generator_to_string_generator(
                        self.llm.stream(prompt=prompt)
                    )
                )

                response_content = ""
                for chunk in response_stream:
                    response_content += (
                        chunk.answer_piece
                        if hasattr(chunk, "answer_piece")
                        and chunk.answer_piece is not None
                        else str(chunk)
                    )
                    yield chunk

                yield "FINAL TOKEN"

                # Update message history with LLM response
                self.message_history.append(
                    PreviousMessage(
                        message=response_content,
                        message_type=MessageType.ASSISTANT,
                        token_count=10,  # You may want to implement a token counting method
                        tool_call=None,
                        files=[],
                    )
                )
     def _raw_output_for_explicit_tool_calling_llms(
         self,
-    ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
+    ) -> Iterator[
+        str
+        | ToolCallKickoff
+        | ToolResponse
+        | ToolCallMetadata
+        | AnswerQuestionPossibleReturn
+        | str
+    ]:
         prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)

         # special things we need to keep track of for the SearchTool
+        # search_results: list[
+        #     LlmDoc
+        # ] | None = None  # raw results that will be displayed to the user
+        # final_context_docs: list[
+        #     LlmDoc
+        # ] | None = None  # processed docs to feed into the LLM
         tool_call_chunk: AIMessageChunk | None = None
         if self.force_use_tool.force_use and self.force_use_tool.args is not None:
             # if we are forcing a tool WITH args specified, we don't need to check which tools to run
@@ -235,6 +463,8 @@ class Answer:
                     tool_call_chunk += message  # type: ignore
             else:
                 if message.content:
+                    if self.is_cancelled:
+                        return
                     yield cast(str, message.content)

         if not tool_call_chunk:
@@ -290,20 +520,49 @@ class Answer:
             )
         )
         yield tool_runner.tool_final_result()

         # Streaming response
         prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
-        yield from message_generator_to_string_generator(
+        for token in message_generator_to_string_generator(
             self.llm.stream(
                 prompt=prompt,
                 tools=[tool.tool_definition() for tool in self.tools],
             )
-        )
+        ):
+            if self.is_cancelled:
+                return
+            yield token

+        return
+        # ADD BACK IN
+        # process_answer_stream_fn = _get_answer_stream_processor(
+        #     context_docs=[],
+        #     # if doc selection is enabled, then search_results will be None,
+        #     # so we need to use the final_context_docs
+        #     doc_id_to_rank_map=map_document_id_order([]),
+        #     answer_style_configs=self.answer_style_config,
+        # )

+        # yield from process_answer_stream_fn(
+        #     message_generator_to_string_generator(self.llm.stream(prompt=prompt))
+        # )

     def _raw_output_for_non_explicit_tool_calling_llms(
         self,
-    ) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
+    ) -> Iterator[
+        str
+        | ToolCallKickoff
+        | ToolResponse
+        | ToolCallMetadata
+        | AnswerQuestionStreamReturn
+        | DanswerAnswerPiece
+        | DanswerQuotes
+        | CitationInfo
+        | DanswerContexts
+        | ImageGenerationDisplay
+        | CustomToolResponse
+        | StreamingError
+        | Delimiter
+    ]:
         prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
         chosen_tool_and_args: tuple[Tool, dict] | None = None
@@ -378,9 +637,13 @@ class Answer:
                 )
             )
             prompt = prompt_builder.build()
-            yield from message_generator_to_string_generator(
+            for token in message_generator_to_string_generator(
                 self.llm.stream(prompt=prompt)
-            )
+            ):
+                if self.is_cancelled:
+                    return
+                yield token

             return

         tool, tool_args = chosen_tool_and_args
@@ -430,11 +693,28 @@ class Answer:
             )
         )
         final = tool_runner.tool_final_result()

         yield final

         # Streaming response
         prompt = prompt_builder.build()
-        yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt))
+        for token in message_generator_to_string_generator(
+            self.llm.stream(prompt=prompt)
+        ):
+            if self.is_cancelled:
+                return
+            yield token
+        # Add this back in
+        # process_answer_stream_fn = _get_answer_stream_processor(
+        #     context_docs=final_context_documents or [],
+        #     # if doc selection is enabled, then search_results will be None,
+        #     # so we need to use the final_context_docs
+        #     doc_id_to_rank_map=map_document_id_order(final_context_documents or []),
+        #     answer_style_configs=self.answer_style_config,
+        # )

+        # yield from process_answer_stream_fn(
+        #     message_generator_to_string_generator(self.llm.stream(prompt=prompt))
+        # )

     @property
     def processed_streamed_output(self) -> AnswerStream:
@@ -443,97 +723,102 @@ class Answer:
             return

         output_generator = (
-            self._raw_output_for_explicit_tool_calling_llms()
+            self._raw_output_for_explicit_tool_calling_llms_loop()
             if explicit_tool_calling_supported(
                 self.llm.config.model_provider, self.llm.config.model_name
             )
             and not self.skip_explicit_tool_calling
             else self._raw_output_for_non_explicit_tool_calling_llms()
         )
+        print(f" output generator {output_generator} ")

+        self.processing_stream = []

         def _process_stream(
-            stream: Iterator[ToolCallKickoff | ToolResponse | str],
+            stream: Iterator[
+                str
+                | ToolCallKickoff
+                | ToolResponse
+                | ToolCallMetadata
+                | Delimiter
+                | Any
+            ],
         ) -> AnswerStream:
             message = None

             # special things we need to keep track of for the SearchTool
             search_results: list[
                 LlmDoc
             ] | None = None  # raw results that will be displayed to the user
             final_context_docs: list[
                 LlmDoc
             ] | None = None  # processed docs to feed into the LLM

             for message in stream:
                 if isinstance(message, ToolCallKickoff) or isinstance(
-                    message, ToolCallFinalResult
+                    message, ToolCallMetadata
                 ):
                     yield message
                 elif isinstance(message, ToolResponse):
                     if message.id == SEARCH_RESPONSE_SUMMARY_ID:
+                        pass
                         # We don't need to run section merging in this flow, this variable is only used
                         # below to specify the ordering of the documents for the purpose of matching
                         # citations to the right search documents. The deduplication logic is more lightweight
                         # there and we don't need to do it twice
-                        search_results = [
-                            llm_doc_from_inference_section(section)
-                            for section in cast(
-                                SearchResponseSummary, message.response
-                            ).top_sections
-                        ]
+                        # search_results = [
+                        #     llm_doc_from_inference_section(section)
+                        #     for section in cast(
+                        #         SearchResponseSummary, message.response
+                        #     ).top_sections
+                        # ]
                     elif message.id == FINAL_CONTEXT_DOCUMENTS:
                         final_context_docs = cast(list[LlmDoc], message.response)
+                        self.final_context_docs = final_context_docs
                     elif (
                         message.id == SEARCH_DOC_CONTENT_ID
                         and not self._return_contexts
                     ):
                         continue

                     yield message
                 else:
-                    # assumes all tool responses will come first, then the final answer
-                    break
-
-            if not self.skip_gen_ai_answer_generation:
-                process_answer_stream_fn = _get_answer_stream_processor(
-                    context_docs=final_context_docs or [],
-                    # if doc selection is enabled, then search_results will be None,
-                    # so we need to use the final_context_docs
-                    doc_id_to_rank_map=map_document_id_order(
-                        search_results or final_context_docs or []
-                    ),
-                    answer_style_configs=self.answer_style_config,
-                )
-
-                def _stream() -> Iterator[str]:
-                    if message:
-                        yield cast(str, message)
-                    yield from cast(Iterator[str], stream)
-
-                yield from process_answer_stream_fn(_stream())
+                    if message == "FINAL TOKEN":
+                        self.current_streamed_output = self.processing_stream
+                        self.processing_stream = []
+                        yield Delimiter(delimiter=True)
+                    elif isinstance(message, str):
+                        yield DanswerAnswerPiece(answer_piece=str(message))
+                    else:
+                        yield message

-        processed_stream = []
         for processed_packet in _process_stream(output_generator):
-            processed_stream.append(processed_packet)
+            self.processing_stream.append(processed_packet)
             yield processed_packet

-        self._processed_stream = processed_stream
+        self._processed_stream = self.processing_stream

     @property
     def llm_answer(self) -> str:
         answer = ""
-        for packet in self.processed_streamed_output:
+        if not self._processed_stream and not self.current_streamed_output:
+            return ""
+        for packet in self.current_streamed_output or self._processed_stream or []:
             if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
                 answer += packet.answer_piece

         return answer

     @property
     def citations(self) -> list[CitationInfo]:
         citations: list[CitationInfo] = []
-        for packet in self.processed_streamed_output:
+        for packet in self.current_streamed_output:
             if isinstance(packet, CitationInfo):
                 citations.append(packet)

         return citations

+    @property
+    def is_cancelled(self) -> bool:
+        if self._is_cancelled:
+            return True
+
+        if self.is_connected is not None:
+            if not self.is_connected():
+                logger.debug("Answer stream has been cancelled")
+            self._is_cancelled = not self.is_connected()
+
+        return self._is_cancelled
@@ -16,7 +16,7 @@ from danswer.configs.constants import MessageType
 from danswer.file_store.models import InMemoryChatFile
 from danswer.llm.override_models import PromptOverride
 from danswer.llm.utils import build_content_with_imgs
-from danswer.tools.models import ToolCallFinalResult
+from danswer.tools.models import ToolCallMetadata

 if TYPE_CHECKING:
     from danswer.db.models import ChatMessage
@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
     token_count: int
     message_type: MessageType
     files: list[InMemoryChatFile]
-    tool_calls: list[ToolCallFinalResult]
+    tool_call: ToolCallMetadata | None

     @classmethod
     def from_chat_message(
@@ -51,14 +51,13 @@ class PreviousMessage(BaseModel):
                 for file in available_files
                 if str(file.file_id) in message_file_ids
             ],
-            tool_calls=[
-                ToolCallFinalResult(
-                    tool_name=tool_call.tool_name,
-                    tool_args=tool_call.tool_arguments,
-                    tool_result=tool_call.tool_result,
-                )
-                for tool_call in chat_message.tool_calls
-            ],
+            tool_call=ToolCallMetadata(
+                tool_name=chat_message.tool_call.tool_name,
+                tool_args=chat_message.tool_call.tool_arguments,
+                tool_result=chat_message.tool_call.tool_result,
+            )
+            if chat_message.tool_call
+            else None,
         )

     def to_langchain_msg(self) -> BaseMessage:
@@ -133,6 +133,10 @@ class AnswerPromptBuilder:
             )
         )

-        return drop_messages_history_overflow(
+        response = drop_messages_history_overflow(
             final_messages_with_tokens, self.max_tokens
         )
+        for msg in response:
+            print(f"{msg.type} : \t \t ||||||||| {msg.content[:20]}")
+
+        return response
@@ -9,6 +9,7 @@ from requests import Timeout

 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
+from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.interfaces import LLM
 from danswer.llm.interfaces import ToolChoiceOptions
 from danswer.llm.utils import convert_lm_input_to_basic_string
@@ -73,6 +74,12 @@ class CustomModelServer(LLM):
         response_content = json.loads(response.content).get("generated_text", "")
         return AIMessage(content=response_content)

+    @classmethod
+    def create_prompt(cls, message: PreviousMessage) -> str:
+        # TODO improve / iterate
+        return f'I searched for some things! """things that I searched for!: {message.message}"""'
+
     def log_model_configs(self) -> None:
         logger.debug(f"Custom model at: {self._endpoint}")
backend/danswer/llm/temporary.py (new file, 14 lines)

@@ -0,0 +1,14 @@
from danswer.configs.constants import MessageType
from danswer.llm.answering.models import PreviousMessage


def create_previous_message(
    assistant_content: str, token_count: int
) -> PreviousMessage:
    return PreviousMessage(
        message=assistant_content,
        message_type=MessageType.ASSISTANT,
        token_count=token_count,
        files=[],
        tool_call=None,
    )
@@ -10,11 +10,11 @@ import litellm  # type: ignore
 import tiktoken
 from langchain.prompts.base import StringPromptValue
 from langchain.prompts.chat import ChatPromptValue
-from langchain.schema import AIMessage
-from langchain.schema import BaseMessage
-from langchain.schema import HumanMessage
 from langchain.schema import PromptValue
 from langchain.schema.language_model import LanguageModelInput
+from langchain.schema.messages import AIMessage
+from langchain.schema.messages import BaseMessage
+from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
 from litellm.exceptions import APIConnectionError  # type: ignore
 from litellm.exceptions import APIError  # type: ignore
@@ -46,6 +46,27 @@ if TYPE_CHECKING:

 logger = setup_logger()

+# def translate_danswer_msg_to_langchain(
+#     msg: Union[ChatMessage, "PreviousMessage"]
+# ) -> BaseMessage:
+#     files: list[InMemoryChatFile] = []
+
+#     # If the message is a `ChatMessage`, it doesn't have the downloaded files
+#     # attached. Just ignore them for now. Also, OpenAI doesn't allow files to
+#     # be attached to AI messages, so we must remove them
+#     if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
+#         files = msg.files
+#     content = build_content_with_imgs(msg.message, files)

+#     if msg.message_type == MessageType.SYSTEM:
+#         print("SYSTEM MESSAGE")
+#         print(msg.message)
+#         # raise ValueError("System messages are not currently part of history")
+#     if msg.message_type == MessageType.ASSISTANT:
+#         return AIMessage(content=content)
+#     if msg.message_type == MessageType.USER:
+#         return HumanMessage(content=content)

 def litellm_exception_to_error_msg(e: Exception, llm: LLM) -> str:
     error_msg = str(e)
@@ -100,21 +121,38 @@ def litellm_exception_to_error_msg(e: Exception, llm: LLM) -> str:


 def translate_danswer_msg_to_langchain(
-    msg: Union[ChatMessage, "PreviousMessage"],
+    msg: Union[ChatMessage, "PreviousMessage"], token_count: int
 ) -> BaseMessage:
-    files: list[InMemoryChatFile] = []
-
-    # If the message is a `ChatMessage`, it doesn't have the downloaded files
-    # attached. Just ignore them for now. Also, OpenAI doesn't allow files to
-    # be attached to AI messages, so we must remove them
-    if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
-        files = msg.files
-    content = build_content_with_imgs(msg.message, files)
+    content = msg.message

     if msg.message_type == MessageType.SYSTEM:
-        raise ValueError("System messages are not currently part of history")
+        return SystemMessage(content=content)

+    wrapped_content = ""
     if msg.message_type == MessageType.ASSISTANT:
-        return AIMessage(content=content)
+        try:
+            parsed_content = json.loads(content)
+            if (
+                "name" in parsed_content
+                and parsed_content["name"] == "run_image_generation"
+            ):
+                wrapped_content += f"I, the AI, am now generating an \
+image based on the prompt: '{parsed_content['args']['prompt']}'\n"
+                wrapped_content += "[/AI IMAGE GENERATION REQUEST]"
+            elif (
+                "id" in parsed_content
+                and parsed_content["id"] == "image_generation_response"
+            ):
+                wrapped_content += "I, the AI, have generated the following image(s) based on the previous request:\n"
+                for img in parsed_content["response"]:
+                    wrapped_content += f"- Description: {img['revised_prompt']}\n"
+                    wrapped_content += f"  Image URL: {img['url']}\n\n"
+                wrapped_content += "[/AI IMAGE GENERATION RESPONSE]"
+            else:
+                wrapped_content = content
+        except json.JSONDecodeError:
+            wrapped_content = content
+        return AIMessage(content=wrapped_content)

     if msg.message_type == MessageType.USER:
         return HumanMessage(content=content)
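The rewritten assistant-message translation wraps serialized image-generation tool traffic in sentinel text so the LLM sees readable history instead of raw JSON. Roughly, for a stored request message like the following (field names taken from the code above), the resulting AIMessage content would be:

import json

stored = json.dumps(
    {"name": "run_image_generation", "args": {"prompt": "a red bicycle"}}
)
# translate_danswer_msg_to_langchain(msg, 0) on a message holding `stored`
# would yield an AIMessage whose content reads:
#   I, the AI, am now generating an image based on the prompt: 'a red bicycle'
#   [/AI IMAGE GENERATION REQUEST]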
@@ -124,15 +162,103 @@ def translate_danswer_msg_to_langchain(

def translate_history_to_basemessages(
    history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
    print("message history is")
    new_history = []
    assistant_content = None
    token_count = 1
    from danswer.llm.temporary import create_previous_message
    from danswer.tools.tool import ToolRegistry
    from danswer.llm.answering.models import PreviousMessage

    for i, msg in enumerate(history):
        message = cast(ChatMessage, msg)
        if message.message_type != MessageType.ASSISTANT:
            if assistant_content is not None:
                combined_ai_message = create_previous_message(
                    assistant_content, token_count
                )
                assistant_content = None
                new_history.append(combined_ai_message)
            new_history.append(cast(PreviousMessage, message))
            continue

        if message.tool_call and message.tool_call.tool_name == "run_image_generation":
            assistant_content = (assistant_content or "") + ToolRegistry.get_prompt(
                "run_image_generation", message
            )

            # assistant_content = assistant_content or "" + f"I generated images with these descriptions! {message.message}"
        else:
            assistant_content = message.message

    # TODO make better + fix token counting
    if assistant_content is not None:
        combined_ai_message = create_previous_message(assistant_content, token_count)
        new_history.append(combined_ai_message)

    history = new_history
    for h in history:
        print(f"\t\t{h.message_type}: \t\t|| {h.message[:100]}\n\n")

    history_basemessages = [
-       translate_danswer_msg_to_langchain(msg)
+       translate_danswer_msg_to_langchain(msg, 0)
        for msg in history
        if msg.token_count != 0
    ]

    history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
    # summary = "[CONVERSATION SUMMARY]\n"
    # summary += "The most recent user request may involve generating additional images. "
    # summary += "I should carefully review the conversation history and the latest user request and almost definitely GENERATE MORE IMAGES "
    # summary += "[/CONVERSATION SUMMARY]"
    # history_basemessages.append(AIMessage(content=summary))
    # history_token_counts.append(100)
    # print()
    for msg in history_basemessages:
        print(f"{msg.type} : \t \t ||||||||| {msg.content[:20]}")

        # print(f"{msg.type}: {msg.content[:20]}")

    return history_basemessages, history_token_counts

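The loop above folds consecutive ASSISTANT messages into one combined message before translation. A stripped-down sketch of just that merge step, with `Msg` standing in for `ChatMessage` (the real code also threads token counts through `create_previous_message`):

```python
from dataclasses import dataclass

@dataclass
class Msg:
    role: str  # "user" or "assistant"
    text: str

def merge_assistant_runs(history: list[Msg]) -> list[Msg]:
    # Buffer assistant text until a non-assistant message flushes it.
    merged: list[Msg] = []
    buffer: str | None = None
    for msg in history:
        if msg.role != "assistant":
            if buffer is not None:
                merged.append(Msg("assistant", buffer))
                buffer = None
            merged.append(msg)
        else:
            buffer = msg.text if buffer is None else buffer + msg.text
    if buffer is not None:
        merged.append(Msg("assistant", buffer))
    return merged

assert [m.text for m in merge_assistant_runs(
    [Msg("user", "hi"), Msg("assistant", "a"), Msg("assistant", "b")]
)] == ["hi", "ab"]
```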
# def translate_history_to_basemessages(
#     history: Union[list[ChatMessage], list["PreviousMessage"]]
# ) -> tuple[list[BaseMessage], list[int]]:
#     history_basemessages = []
#     history_token_counts = []
#     image_generation_count = 0

#     for msg in history:
#         if msg.token_count != 0:
#             translated_msg = translate_danswer_msg_to_langchain(
#                 msg, image_generation_count
#             )
#             if (
#                 isinstance(translated_msg.content, str)
#                 and "[ImageGenerationRe" in translated_msg.content
#             ):
#                 image_generation_count += 1
#             history_basemessages.append(translated_msg)
#             history_token_counts.append(msg.token_count)

#     # Add a generic summary message at the end
#     summary = "[CONVERSATION SUMMARY]\n"
#     summary += "The most recent user request may involve generating additional images. "
#     summary += "I should carefully review the conversation history and the latest user request "
#     summary += (
#         "to determine if any new images need to be generated, ensuring I don't repeat "
#     )
#     summary += "any image generations that have already been completed.\n"
#     summary += f"I already generated {image_generation_count} images thus far. I should keep my responses EXTREMELY SHORT"
#     summary += "[/CONVERSATION SUMMARY]"
#     history_basemessages.append(AIMessage(content=summary))
#     history_token_counts.append(100)

#     return history_basemessages, history_token_counts

def _build_content(
    message: str,
    files: list[InMemoryChatFile] | None = None,

@@ -189,15 +315,15 @@ def build_content_with_imgs(
            for file in files
            if file.file_type == "image"
        ]
-       + [
-           {
-               "type": "image_url",
-               "image_url": {
-                   "url": url,
-               },
-           }
-           for url in img_urls
-       ],
+       # + [
+       #     {
+       #         "type": "image_url",
+       #         "image_url": {
+       #             "url": url,
+       #         },
+       #     }
+       #     for url in img_urls
+       # ],
    )

@@ -32,7 +32,7 @@ def check_if_need_search_multi_message(
        return True

    prompt_msgs: list[BaseMessage] = [SystemMessage(content=REQUIRE_SEARCH_SYSTEM_MSG)]
-   prompt_msgs.extend([translate_danswer_msg_to_langchain(msg) for msg in history])
+   prompt_msgs.extend([translate_danswer_msg_to_langchain(msg, 2) for msg in history])

    last_query = query_message.message

@@ -1,5 +1,8 @@
+import asyncio
import io
import uuid
+from collections.abc import Callable
+from collections.abc import Generator

from fastapi import APIRouter
from fastapi import Depends

@@ -207,8 +210,6 @@ def rename_chat_session(
    chat_session_id = rename_req.chat_session_id
    user_id = user.id if user is not None else None

-   logger.info(f"Received rename request for chat session: {chat_session_id}")

    if name:
        update_chat_session(
            db_session=db_session,
@@ -271,19 +272,39 @@ def delete_chat_session_by_id(
    delete_chat_session(user_id, session_id, db_session)


+async def is_disconnected(request: Request) -> Callable[[], bool]:
+    main_loop = asyncio.get_event_loop()
+
+    def is_disconnected_sync() -> bool:
+        future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
+        try:
+            return not future.result(timeout=0.01)
+        except asyncio.TimeoutError:
+            logger.error("Asyncio timed out")
+            return True
+        except Exception as e:
+            error_msg = str(e)
+            logger.critical(
+                f"An unexpected error occurred with the disconnect check coroutine: {error_msg}"
+            )
+            return True
+
+    return is_disconnected_sync


@router.post("/send-message")
def handle_new_chat_message(
    chat_message_req: CreateChatMessageRequest,
    request: Request,
    user: User | None = Depends(current_user),
    _: None = Depends(check_token_rate_limits),
+   is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
) -> StreamingResponse:
    """This endpoint is used for all of the following purposes:
    - Sending a new message in the session
    - Regenerating a message in the session (just send the same one again)
    - Editing a message (similar to regenerating but sending a different message)
    - Kicking off a seeded chat session (set `use_existing_user_message`)

    To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
    have already been set as latest"""
    logger.debug(f"Received new chat message: {chat_message_req.message}")

@@ -295,15 +316,26 @@ def handle_new_chat_message(
    ):
        raise HTTPException(status_code=400, detail="Empty chat message is invalid")

-   packets = stream_chat_message(
-       new_msg_req=chat_message_req,
-       user=user,
-       use_existing_user_message=chat_message_req.use_existing_user_message,
-       litellm_additional_headers=get_litellm_additional_request_headers(
-           request.headers
-       ),
-   )
-   return StreamingResponse(packets, media_type="application/json")
+   import json
+
+   def stream_generator() -> Generator[str, None, None]:
+       try:
+           for packet in stream_chat_message(
+               new_msg_req=chat_message_req,
+               user=user,
+               use_existing_user_message=chat_message_req.use_existing_user_message,
+               litellm_additional_headers=get_litellm_additional_request_headers(
+                   request.headers
+               ),
+               is_connected=is_disconnected_func,
+           ):
+               yield json.dumps(packet) if isinstance(packet, dict) else packet
+
+       except Exception as e:
+           logger.exception(f"Error in chat message streaming: {e}")
+           yield json.dumps({"error": str(e)})
+
+   return StreamingResponse(stream_generator(), media_type="text/event-stream")


@router.put("/set-message-as-latest")

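The pattern behind `is_disconnected` above: the streaming generator runs synchronously in a worker thread, so it cannot `await request.is_disconnected()` directly. Instead the coroutine is scheduled back onto the main event loop with `run_coroutine_threadsafe` and polled with a short timeout; on timeout or error it assumes the client is still connected. Note the function actually returns a connected flag, which is why it is wired to the `is_connected=` parameter despite its name. A minimal sketch of the consumer side, with an illustrative producer loop:

```python
from collections.abc import Callable, Generator

def stream_packets(is_connected: Callable[[], bool]) -> Generator[str, None, None]:
    # Bail out early once the poll reports a disconnect, so expensive
    # LLM/tool work stops as soon as the client goes away.
    for i in range(1000):
        if not is_connected():
            break
        yield f"packet {i}\n"

# Usage: list(stream_packets(lambda: True)) yields all packets;
# a callable that flips to False stops the stream early.
```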
@@ -17,7 +17,7 @@ from danswer.search.models import ChunkContext
from danswer.search.models import RetrievalDetails
from danswer.search.models import SearchDoc
from danswer.search.models import Tag
-from danswer.tools.models import ToolCallFinalResult
+from danswer.tools.models import ToolCallMetadata


class SourceTag(Tag):

@@ -97,7 +97,9 @@ class CreateChatMessageRequest(ChunkContext):
    # allows the caller to specify the exact search query they want to use
    # will disable Query Rewording if specified
    query_override: str | None = None
+   alternate_model: str | None = None  # optional alternate model for this message
+
+   regenerate: bool | None = None
    # allows the caller to override the Persona / Prompt
    llm_override: LLMOverride | None = None
    prompt_override: PromptOverride | None = None

@@ -179,11 +181,12 @@ class ChatMessageDetail(BaseModel):
    message_type: MessageType
    time_sent: datetime
    alternate_assistant_id: str | None
+   alternate_model: str | None
    # Dict mapping citation number to db_doc_id
    chat_session_id: int | None = None
    citations: dict[int, int] | None
    files: list[FileDescriptor]
-   tool_calls: list[ToolCallFinalResult]
+   tool_call: ToolCallMetadata | None

    def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        initial_dict = super().dict(*args, **kwargs)  # type: ignore

@@ -60,6 +60,12 @@ class CustomTool(Tool):

    """For LLMs which support explicit tool calling"""

+   @classmethod
+   def create_prompt(cls, message: PreviousMessage) -> str:
+       # TODO improve / iterate
+
+       return f'I searched for some things! """things that I searched for!: {message.message}"""'

    def tool_definition(self) -> dict:
        return self._tool_definition

@@ -17,6 +17,7 @@ from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.tools.tool import Tool
+from danswer.tools.tool import ToolRegistry
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

@@ -31,7 +32,7 @@ YES_IMAGE_GENERATION = "Yes Image Generation"
SKIP_IMAGE_GENERATION = "Skip Image Generation"

IMAGE_GENERATION_TEMPLATE = f"""
-Given the conversation history and a follow up query, determine if the system should call \
+Given the conversation history and a follow up query, determine if the system should call
an external image generation tool to better answer the latest user input.
Your default response is {SKIP_IMAGE_GENERATION}.

@@ -62,6 +63,7 @@ class ImageShape(str, Enum):
    LANDSCAPE = "landscape"


+@ToolRegistry.register("run_image_generation")
class ImageGenerationTool(Tool):
    _NAME = "run_image_generation"
    _DESCRIPTION = "Generate an image from a prompt."

@@ -121,6 +123,11 @@ class ImageGenerationTool(Tool):
        },
    }

+   @classmethod
+   def create_prompt(cls, message: PreviousMessage) -> str:
+       # TODO improve / iterate
+       return f'I generated images with these descriptions! """Descriptions: {message.message}"""'

    def get_args_for_non_tool_calling_llm(
        self,
        query: str,

@@ -210,7 +217,8 @@ class ImageGenerationTool(Tool):
            in error_message
        ):
            raise ValueError(
-               "The image generation request was rejected due to OpenAI's content policy. Please try a different prompt."
+               "The image generation request was rejected due to OpenAI's content policy."
+               + " Please try a different prompt."
            )
        elif "Invalid image URL" in error_message:
            raise ValueError("Invalid image URL provided for image generation.")

@@ -6,7 +6,7 @@ from danswer.llm.utils import build_content_with_imgs
IMG_GENERATION_SUMMARY_PROMPT = """
You have just created the attached images in response to the following query: "{query}".

-Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
+Can you please summarize them in a sentence or two? NEVER include image urls or bulleted lists.
"""


@@ -119,6 +119,12 @@ class InternetSearchTool(Tool):
    def display_name(self) -> str:
        return self._DISPLAY_NAME

+   @classmethod
+   def create_prompt(cls, message: PreviousMessage) -> str:
+       # TODO improve / iterate
+
+       return f'I searched for some things! """things that I searched for!: {message.message}"""'

    def tool_definition(self) -> dict:
        return {
            "type": "function",

@@ -35,5 +35,5 @@ class ToolRunnerResponse(BaseModel):
        return values


-class ToolCallFinalResult(ToolCallKickoff):
+class ToolCallMetadata(ToolCallKickoff):
    tool_result: Any  # we would like to use JSON_ro, but can't due to its recursive nature

@@ -32,6 +32,7 @@ from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.search.search_utils import llm_doc_to_dict
from danswer.tools.tool import Tool
+from danswer.tools.tool import ToolRegistry
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger

@@ -65,6 +66,7 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
"""


+@ToolRegistry.register(SEARCH_RESPONSE_SUMMARY_ID)
class SearchTool(Tool):
    _NAME = "run_search"
    _DISPLAY_NAME = "Search Tool"

@@ -175,6 +177,12 @@ class SearchTool(Tool):
        )
        return {"query": rephrased_query}

+   @classmethod
+   def create_prompt(cls, message: PreviousMessage) -> str:
+       # TODO improve / iterate
+
+       return f'I searched for some things! """things that I searched for!: {message.message}"""'

    """Actual tool execution"""

    def _build_response_for_specified_sections(

@@ -1,12 +1,18 @@
import abc
+from collections.abc import Callable
from collections.abc import Generator
from typing import Any
+from typing import cast
+from typing import Dict
+from typing import Type

from danswer.dynamic_configs.interface import JSON_ro
+from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.models import ToolResponse

+# from danswer.llm.answering.models import ChatMessage


class Tool(abc.ABC):
    @property

@@ -24,6 +30,11 @@ class Tool(abc.ABC):
    def display_name(self) -> str:
        raise NotImplementedError

+   @classmethod
+   @abc.abstractmethod
+   def create_prompt(cls, message: PreviousMessage) -> str:
+       raise NotImplementedError

    """For LLMs which support explicit tool calling"""

    @abc.abstractmethod

@@ -61,3 +72,29 @@ class Tool(abc.ABC):
        It is the result that will be stored in the database.
        """
        raise NotImplementedError


+class ToolRegistry:
+    _registry: Dict[str, Type[Tool]] = {}
+
+    @classmethod
+    def register(cls, tool_id: str) -> Callable[[type[Tool]], type[Tool]]:
+        def decorator(tool_class: Type[Tool]) -> type[Tool]:
+            cls._registry[tool_id] = tool_class
+            return tool_class
+
+        return decorator
+
+    @classmethod
+    def get_tool(cls, tool_id: str) -> Type[Tool]:
+        if tool_id not in cls._registry:
+            raise ValueError(f"No tool registered with id: {tool_id}")
+        return cls._registry[tool_id]
+
+    @classmethod
+    def get_prompt(cls, tool_id: str, message: PreviousMessage) -> str:
+        if tool_id not in cls._registry:
+            raise ValueError(f"No tool registered with id: {tool_id}")
+        tool = cls._registry[tool_id]
+        new_prompt = tool.create_prompt(message=cast(PreviousMessage, message))
+        return new_prompt

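A self-contained sketch of the registry pattern just introduced: tools self-register at import time via the decorator, and call sites resolve a prompt by tool id without importing the concrete class. `MiniTool`/`MiniRegistry` are trimmed stand-ins so the example runs on its own; the real classes carry more abstract methods:

```python
from typing import Callable, Dict, Type

class MiniTool:
    @classmethod
    def create_prompt(cls, message: str) -> str:
        raise NotImplementedError

class MiniRegistry:
    _registry: Dict[str, Type[MiniTool]] = {}

    @classmethod
    def register(cls, tool_id: str) -> Callable[[Type[MiniTool]], Type[MiniTool]]:
        # Decorator factory: stores the class under tool_id, returns it unchanged.
        def decorator(tool_class: Type[MiniTool]) -> Type[MiniTool]:
            cls._registry[tool_id] = tool_class
            return tool_class
        return decorator

    @classmethod
    def get_prompt(cls, tool_id: str, message: str) -> str:
        if tool_id not in cls._registry:
            raise ValueError(f"No tool registered with id: {tool_id}")
        return cls._registry[tool_id].create_prompt(message)

@MiniRegistry.register("run_example")
class ExampleTool(MiniTool):
    @classmethod
    def create_prompt(cls, message: str) -> str:
        return f"I ran the example tool on: {message}"

print(MiniRegistry.get_prompt("run_example", "hello"))
# -> I ran the example tool on: hello
```

The design choice here is the same one the diff makes: registration by string id breaks the import cycle between the history-translation code and the concrete tool modules.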
@@ -3,8 +3,8 @@ from typing import Any

from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
-from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
+from danswer.tools.models import ToolCallMetadata
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel

@@ -36,8 +36,8 @@ class ToolRunner:
        tool_responses = list(self.tool_responses())
        return self.tool.build_tool_message_content(*tool_responses)

-   def tool_final_result(self) -> ToolCallFinalResult:
-       return ToolCallFinalResult(
+   def tool_final_result(self) -> ToolCallMetadata:
+       return ToolCallMetadata(
            tool_name=self.tool.name,
            tool_args=self.args,
            tool_result=self.tool.final_result(*self.tool_responses()),

@@ -23,7 +23,6 @@ from ee.danswer.server.enterprise_settings.store import (
)
from ee.danswer.server.enterprise_settings.store import upload_logo


logger = setup_logger()

_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"

@@ -36,7 +35,8 @@ class SeedConfiguration(BaseModel):
    personas: list[CreatePersonaRequest] | None = None
    settings: Settings | None = None
    enterprise_settings: EnterpriseSettings | None = None
-   analytics_script: AnalyticsScriptUpload | None = None
+   analytics_script_key: str | None = None
+   analytics_script_path: str | None = None


def _parse_env() -> SeedConfiguration | None:

@@ -119,10 +119,19 @@ def _seed_logo(db_session: Session, logo_path: str | None) -> None:


def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
-   if seed_config.analytics_script is not None:
+   if seed_config.analytics_script_path and seed_config.analytics_script_key:
        logger.info("Seeding analytics script")
        try:
-           store_analytics_script(seed_config.analytics_script)
+           with open(seed_config.analytics_script_path, "r") as file:
+               script_content = file.read()
+           analytics_script = AnalyticsScriptUpload(
+               script=script_content, secret_key=seed_config.analytics_script_key
+           )
+           store_analytics_script(analytics_script)
+       except FileNotFoundError:
+           logger.error(
+               f"Analytics script file not found: {seed_config.analytics_script_path}"
+           )
        except ValueError as e:
            logger.error(f"Failed to seed analytics script: {str(e)}")

@@ -133,7 +142,6 @@ def get_seed_config() -> SeedConfiguration | None:

def seed_db() -> None:
    seed_config = _parse_env()
    if seed_config is None:
        logger.info("No seeding configuration file passed")
        return

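For orientation, a hedged sketch of driving the reworked analytics seeding: the seed configuration is parsed from the `ENV_SEED_CONFIGURATION` environment variable (per `_SEED_CONFIG_ENV_VAR_NAME` above), and the branch replaces the inline script payload with a file path plus secret key. Field names follow the model above; the values are invented:

```python
import json
import os

# Hypothetical deployment snippet: point the seeder at a script file on disk.
os.environ["ENV_SEED_CONFIGURATION"] = json.dumps(
    {
        "analytics_script_path": "/seed/analytics.js",  # assumed path
        "analytics_script_key": "not-a-real-secret",    # assumed key
    }
)
```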
@@ -49,6 +49,12 @@ ENV NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PRED
ARG NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN
ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN}

+ARG NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH
+ENV NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH=${NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH}

ARG NEXT_PUBLIC_THEME
ENV NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME}

@@ -112,6 +118,9 @@ ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_T
ARG NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN
ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN}

+ARG NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH
+ENV NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH=${NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH}

ARG NEXT_PUBLIC_DISABLE_LOGOUT
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}

@@ -14,26 +14,6 @@ const AddPromptSchema = Yup.object().shape({
});

const AddPromptModal = ({ onClose, onSubmit }: AddPromptModalProps) => {
-  const defaultPrompts = [
-    {
-      title: "Email help",
-      prompt: "Write a professional email addressing the following points:",
-    },
-    {
-      title: "Code explanation",
-      prompt: "Explain the following code snippet in simple terms:",
-    },
-    {
-      title: "Product description",
-      prompt: "Write a compelling product description for the following item:",
-    },
-    {
-      title: "Troubleshooting steps",
-      prompt:
-        "Provide step-by-step troubleshooting instructions for the following issue:",
-    },
-  ];

  return (
    <ModalWrapper onClose={onClose} modalClassName="max-w-xl">
      <Formik

File diff suppressed because it is too large

186
web/src/app/chat/RegenerateOption.tsx
Normal file
@@ -0,0 +1,186 @@
import { useChatContext } from "@/components/context/ChatContext";
import {
  getDisplayNameForModel,
  LlmOverride,
  useLlmOverride,
} from "@/lib/hooks";
import {
  DefaultDropdownElement,
  StringOrNumberOption,
} from "@/components/Dropdown";

import { Persona } from "@/app/admin/assistants/interfaces";
import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
import { useState } from "react";
import { Hoverable } from "@/components/Hoverable";
import { Popover } from "@/components/popover/Popover";
import { FiStar } from "react-icons/fi";
import { StarFeedback } from "@/components/icons/icons";
import { IconType } from "react-icons";

export function RegenerateDropdown({
  options,
  selected,
  onSelect,
  side,
  maxHeight,
  gptBox,
  alternate,
}: {
  alternate?: string;
  options: StringOrNumberOption[];
  selected: string | null;
  onSelect: (value: string | number | null) => void;
  includeDefault?: boolean;
  side?: "top" | "right" | "bottom" | "left";
  maxHeight?: string;
  gptBox?: boolean;
}) {
  const [isOpen, setIsOpen] = useState(false);

  const Dropdown = (
    <div
      className={`
        border
        border
        rounded-lg
        flex
        flex-col
        mx-2
        bg-background
        ${maxHeight || "max-h-96"}
        overflow-y-auto
        overscroll-contain relative`}
    >
      <p
        className="
          sticky
          top-0
          flex
          bg-background
          font-bold
          px-3
          text-sm
          py-1.5
          "
      >
        Pick a model
      </p>
      {options.map((option, ind) => {
        const isSelected = option.value === selected;
        return (
          <DefaultDropdownElement
            key={option.value}
            name={getDisplayNameForModel(option.name)}
            description={option.description}
            onSelect={() => onSelect(option.value)}
            isSelected={isSelected}
          />
        );
      })}
    </div>
  );

  return (
    <Popover
      open={isOpen}
      onOpenChange={(open) => setIsOpen(open)}
      content={
        <div onClick={() => setIsOpen(!isOpen)}>
          {!alternate ? (
            <Hoverable size={16} icon={StarFeedback as IconType} />
          ) : (
            <Hoverable
              size={16}
              icon={StarFeedback as IconType}
              hoverText={getDisplayNameForModel(alternate)}
            />
          )}
        </div>
      }
      popover={Dropdown}
      align="start"
      side={side}
      sideOffset={5}
      triggerMaxWidth
    />
  );
}

export default function RegenerateOption({
  selectedAssistant,
  regenerate,
  alternateModel,
  onHoverChange,
}: {
  selectedAssistant: Persona;
  regenerate: (modelOverRide: LlmOverride) => Promise<void>;
  alternateModel?: string;
  onHoverChange: (isHovered: boolean) => void;
}) {
  const llmOverrideManager = useLlmOverride();

  const { llmProviders } = useChatContext();
  const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);

  const llmOptionsByProvider: {
    [provider: string]: { name: string; value: string }[];
  } = {};
  const uniqueModelNames = new Set<string>();

  llmProviders.forEach((llmProvider) => {
    if (!llmOptionsByProvider[llmProvider.provider]) {
      llmOptionsByProvider[llmProvider.provider] = [];
    }

    (llmProvider.display_model_names || llmProvider.model_names).forEach(
      (modelName) => {
        if (!uniqueModelNames.has(modelName)) {
          uniqueModelNames.add(modelName);
          llmOptionsByProvider[llmProvider.provider].push({
            name: modelName,
            value: structureValue(
              llmProvider.name,
              llmProvider.provider,
              modelName
            ),
          });
        }
      }
    );
  });

  const llmOptions = Object.entries(llmOptionsByProvider).flatMap(
    ([provider, options]) => [...options]
  );

  const currentModelName =
    llmOverrideManager?.llmOverride.modelName ||
    (selectedAssistant
      ? selectedAssistant.llm_model_version_override || llmName
      : llmName);

  return (
    <div
      className="group flex items-center relative"
      onMouseEnter={() => onHoverChange(true)}
      onMouseLeave={() => onHoverChange(false)}
    >
      <RegenerateDropdown
        alternate={alternateModel}
        options={llmOptions}
        selected={currentModelName}
        onSelect={(value) => {
          const { name, provider, modelName } = destructureValue(
            value as string
          );
          regenerate({
            name: name,
            provider: provider,
            modelName: modelName,
          });
        }}
      />
    </div>
  );
}

@@ -9,6 +9,7 @@ import {
  buildDocumentSummaryDisplay,
} from "@/components/search/DocumentDisplay";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
+import { MOBILE_SIDEBAR_CARD_WIDTH, SIDEBAR_CARD_WIDTH } from "@/lib/constants";

interface DocumentDisplayProps {
  document: DanswerDocument;

@@ -39,8 +40,8 @@ export function ChatDocumentDisplay({
  return (
    <div
      key={document.semantic_identifier}
-      className={`p-2 w-[325px] justify-start rounded-md ${
-        isSelected ? "bg-background-200" : "bg-background-125"
+      className={`p-2 desktop:${SIDEBAR_CARD_WIDTH} mobile:${MOBILE_SIDEBAR_CARD_WIDTH} justify-start rounded-md ${
+        isSelected ? "bg-background-200" : "bg-background-150"
      } text-sm mx-3`}
    >
      <div className="flex relative justify-start overflow-y-visible">

@@ -3,12 +3,12 @@ import { Divider, Text } from "@tremor/react";
import { ChatDocumentDisplay } from "./ChatDocumentDisplay";
import { usePopup } from "@/components/admin/connectors/Popup";
import { removeDuplicateDocs } from "@/lib/documentUtils";
-import { Message, RetrievalType } from "../interfaces";
+import { Message } from "../interfaces";
import { ForwardedRef, forwardRef } from "react";

interface DocumentSidebarProps {
  closeSidebar: () => void;
  selectedMessage: Message | null;
  currentDocuments: DanswerDocument[] | null;
  selectedDocuments: DanswerDocument[] | null;
  toggleDocumentSelection: (document: DanswerDocument) => void;
  clearSelectedDocuments: () => void;

@@ -23,7 +23,7 @@ export const DocumentSidebar = forwardRef<HTMLDivElement, DocumentSidebarProps>(
  (
    {
      closeSidebar,
      selectedMessage,
      currentDocuments,
      selectedDocuments,
      toggleDocumentSelection,
      clearSelectedDocuments,

@@ -40,7 +40,6 @@ export const DocumentSidebar = forwardRef<HTMLDivElement, DocumentSidebarProps>(
    const selectedDocumentIds =
      selectedDocuments?.map((document) => document.document_id) || [];

    const currentDocuments = selectedMessage?.documents || null;
    const dedupedDocuments = removeDuplicateDocs(currentDocuments || []);

    // NOTE: do not allow selection if less than 75 tokens are left

@@ -72,14 +71,8 @@ export const DocumentSidebar = forwardRef<HTMLDivElement, DocumentSidebarProps>(
        {popup}
        <div className="pl-3 mx-2 pr-6 mt-3 flex text-text-800 flex-col text-2xl text-emphasis flex font-semibold">
          {dedupedDocuments.length} Documents
-          <p className="text-sm font-semibold flex flex-wrap gap-x-2 text-text-600 mt-1">
+          <p className="text-sm flex flex-wrap gap-x-2 font-normal text-text-700 mt-1">
            Select to add to continuous context
-            <a
-              href="https://docs.danswer.dev/introduction"
-              className="underline cursor-pointer hover:text-strong"
-            >
-              Learn more
-            </a>
          </p>
        </div>

@@ -137,14 +130,14 @@ export const DocumentSidebar = forwardRef<HTMLDivElement, DocumentSidebarProps>(
        <div className="absolute left-0 bottom-0 w-full bg-gradient-to-b from-neutral-100/0 via-neutral-100/40 backdrop-blur-xs to-neutral-100 h-[100px]" />
        <div className="sticky bottom-4 w-full left-0 justify-center flex gap-x-4">
          <button
-            className="bg-[#84e49e] text-xs p-2 rounded text-text-800"
+            className="bg-background-800 px-3 hover:bg-background-600 transition-background duration-300 py-2.5 scale-[.95] rounded text-text-200"
            onClick={() => closeSidebar()}
          >
            Save Changes
          </button>

          <button
-            className="bg-error text-xs p-2 rounded text-text-200"
+            className="bg-background-125 hover:bg-background-150 transition-background duration-300 ring ring-1 ring-border scale-[.95] px-3 py-2.5 rounded text-text-900"
            onClick={() => {
              clearSelectedDocuments();

@@ -41,7 +41,7 @@ export function FullImageModal({
          <img
            src={buildImgUrl(fileId)}
            alt="Uploaded image"
-            className="max-w-full max-h-full"
+            className="max-w-full rounded-lg max-h-full"
          />
        </Dialog.Content>
      </Dialog.Portal>

@@ -15,10 +15,11 @@ export function InMessageImage({ fileId }: { fileId: string }) {
      />

      <div className="relative w-full h-full max-w-96 max-h-96">
-        {!imageLoaded && (
-          <div className="absolute inset-0 bg-gray-200 animate-pulse rounded-lg" />
-        )}

+        <div
+          className={`absolute inset-0 bg-gray-200 rounded-lg transition-opacity duration-300 ${
+            imageLoaded ? "opacity-0" : "opacity-100"
+          }`}
+        />
        <img
          width={1200}
          height={1200}

@@ -21,6 +21,7 @@ import {
  CpuIconSkeleton,
  FileIcon,
  SendIcon,
+  StopGeneratingIcon,
} from "@/components/icons/icons";
import { IconType } from "react-icons";
import Popup from "../../../components/popup/Popup";

@@ -31,6 +32,9 @@ import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { Hoverable } from "@/components/Hoverable";
import { SettingsContext } from "@/components/settings/SettingsProvider";
+import { StopCircle } from "@phosphor-icons/react/dist/ssr";
+import { Square } from "@phosphor-icons/react";
+import { ChatState } from "../types";
const MAX_INPUT_HEIGHT = 200;

export function ChatInputBar({

@@ -39,10 +43,11 @@ export function ChatInputBar({
  selectedDocuments,
  message,
  setMessage,
+  stopGenerating,
  onSubmit,
-  isStreaming,
  filterManager,
  llmOverrideManager,
+  chatState,

  // assistants
  selectedAssistant,

@@ -59,6 +64,8 @@ export function ChatInputBar({
  inputPrompts,
}: {
  openModelSettings: () => void;
+  chatState: ChatState;
+  stopGenerating: () => void;
  showDocs: () => void;
  selectedDocuments: DanswerDocument[];
  assistantOptions: Persona[];

@@ -68,7 +75,6 @@ export function ChatInputBar({
  message: string;
  setMessage: (message: string) => void;
  onSubmit: () => void;
-  isStreaming: boolean;
  filterManager: FilterManager;
  llmOverrideManager: LlmOverrideManager;
  selectedAssistant: Persona;

@@ -597,24 +603,38 @@ export function ChatInputBar({
            }}
          />
        </div>

        <div className="absolute bottom-2.5 mobile:right-4 desktop:right-10">
-          <div
-            className="cursor-pointer"
-            onClick={() => {
-              if (message) {
-                onSubmit();
-              }
-            }}
-          >
-            <SendIcon
-              size={28}
-              className={`text-emphasis text-white p-1 rounded-full ${
-                message && !isStreaming
-                  ? "bg-background-800"
-                  : "bg-[#D7D7D7]"
-              }`}
-            />
-          </div>
+          {chatState == "streaming" ||
+          chatState == "toolBuilding" ||
+          chatState == "loading" ? (
+            <button
+              className={`cursor-pointer ${chatState != "streaming" ? "bg-background-400" : "bg-background-800"} h-[28px] w-[28px] rounded-full`}
+              onClick={stopGenerating}
+              disabled={chatState != "streaming"}
+            >
+              <StopGeneratingIcon
+                size={10}
+                className="text-emphasis m-auto text-white flex-none"
+              />
+            </button>
+          ) : (
+            <button
+              className="cursor-pointer"
+              onClick={() => {
+                if (message) {
+                  onSubmit();
+                }
+              }}
+              disabled={chatState != "input"}
+            >
+              <SendIcon
+                size={28}
+                className={`text-emphasis text-white p-1 rounded-full ${chatState == "input" && message ? "bg-background-800" : "bg-background-400"} `}
+              />
+            </button>
+          )}
        </div>
      </div>
    </div>

@@ -49,12 +49,6 @@ export interface ToolCallMetadata {
  tool_result?: Record<string, any>;
}

-export interface ToolCallFinalResult {
-  tool_name: string;
-  tool_args: Record<string, any>;
-  tool_result: Record<string, any>;
-}

export interface ChatSession {
  id: number;
  name: string;

@@ -72,6 +66,24 @@ export interface SearchSession {
  description: string;
}

+export interface PreviousAIMessage {
+  messageId?: number;
+  message?: string;
+  type?: "assistant";
+  retrievalType?: RetrievalType;
+  query?: string | null;
+  documents?: DanswerDocument[] | null;
+  citations?: CitationMap;
+  files?: FileDescriptor[];
+  toolCall?: ToolCallMetadata | null;
+
+  // for rebuilding the message tree
+  parentMessageId?: number | null;
+  childrenMessageIds?: number[];
+  latestChildMessageId?: number | null;
+  alternateAssistantID?: number | null;
+}

export interface Message {
  messageId: number;
  message: string;

@@ -81,13 +93,15 @@ export interface Message {
  documents?: DanswerDocument[] | null;
  citations?: CitationMap;
  files: FileDescriptor[];
-  toolCalls: ToolCallMetadata[];
+  toolCall: ToolCallMetadata | null;

  // for rebuilding the message tree
  parentMessageId: number | null;
  childrenMessageIds?: number[];
  latestChildMessageId?: number | null;
  alternateAssistantID?: number | null;
+  stackTrace?: string | null;
+  alternate_model?: string;
}

export interface BackendChatSession {

@@ -114,8 +128,14 @@ export interface BackendMessage {
  time_sent: string;
  citations: CitationMap;
  files: FileDescriptor[];
-  tool_calls: ToolCallFinalResult[];
+  tool_call: ToolCallMetadata | null;
  alternate_assistant_id?: number | null;
+  alternate_model?: string;
}

+export interface MessageResponseIDInfo {
+  user_message_id: number | null;
+  reserved_assistant_message_id: number;
+}

export interface DocumentsResponse {

@@ -3,8 +3,8 @@ import {
  DanswerDocument,
  Filters,
} from "@/lib/search/interfaces";
-import { handleStream } from "@/lib/search/streamingUtils";
-import { FeedbackType } from "./types";
+import { handleSSEStream, handleStream } from "@/lib/search/streamingUtils";
+import { ChatState, FeedbackType } from "./types";
import {
  Dispatch,
  MutableRefObject,

@@ -20,6 +20,7 @@ import {
  FileDescriptor,
  ImageGenerationDisplay,
  Message,
+  MessageResponseIDInfo,
  RetrievalType,
  StreamingError,
  ToolCallMetadata,

@@ -109,9 +110,11 @@ export type PacketType =
  | AnswerPiecePacket
  | DocumentsResponse
  | ImageGenerationDisplay
-  | StreamingError;
+  | StreamingError
+  | MessageResponseIDInfo;

export async function* sendMessage({
+  regenerate,
  message,
  fileDescriptors,
  parentMessageId,

@@ -127,7 +130,9 @@ export async function* sendMessage({
  systemPromptOverride,
  useExistingUserMessage,
  alternateAssistantId,
+  signal,
}: {
+  regenerate: boolean;
  message: string;
  fileDescriptors: FileDescriptor[];
  parentMessageId: number | null;

@@ -137,70 +142,70 @@ export async function* sendMessage({
  selectedDocumentIds: number[] | null;
  queryOverride?: string;
  forceSearch?: boolean;
  // LLM overrides
  modelProvider?: string;
  modelVersion?: string;
  temperature?: number;
  // prompt overrides
  systemPromptOverride?: string;
  // if specified, will use the existing latest user message
  // and will ignore the specified `message`
  useExistingUserMessage?: boolean;
  alternateAssistantId?: number;
-}) {
+  signal?: AbortSignal;
+}): AsyncGenerator<PacketType, void, unknown> {
  const documentsAreSelected =
    selectedDocumentIds && selectedDocumentIds.length > 0;

-  const sendMessageResponse = await fetch("/api/chat/send-message", {
+  const body = JSON.stringify({
+    alternate_assistant_id: alternateAssistantId,
+    chat_session_id: chatSessionId,
+    parent_message_id: parentMessageId,
+    message: message,
+    prompt_id: promptId,
+    search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
+    file_descriptors: fileDescriptors,
+    regenerate,
+    retrieval_options: !documentsAreSelected
+      ? {
+          run_search:
+            promptId === null ||
+            promptId === undefined ||
+            queryOverride ||
+            forceSearch
+              ? "always"
+              : "auto",
+          real_time: true,
+          filters: filters,
+        }
+      : null,
+    query_override: queryOverride,
+    prompt_override: systemPromptOverride
+      ? {
+          system_prompt: systemPromptOverride,
+        }
+      : null,
+    llm_override:
+      temperature || modelVersion
+        ? {
+            temperature,
+            model_provider: modelProvider,
+            model_version: modelVersion,
+          }
+        : null,
+    use_existing_user_message: useExistingUserMessage,
+  });

+  const response = await fetch(`/api/chat/send-message`, {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
-    body: JSON.stringify({
-      alternate_assistant_id: alternateAssistantId,
-      chat_session_id: chatSessionId,
-      parent_message_id: parentMessageId,
-      message: message,
-      prompt_id: promptId,
-      search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
-      file_descriptors: fileDescriptors,
-      retrieval_options: !documentsAreSelected
-        ? {
-            run_search:
-              promptId === null ||
-              promptId === undefined ||
-              queryOverride ||
-              forceSearch
-                ? "always"
-                : "auto",
-            real_time: true,
-            filters: filters,
-          }
-        : null,
-      query_override: queryOverride,
-      prompt_override: systemPromptOverride
-        ? {
-            system_prompt: systemPromptOverride,
-          }
-        : null,
-      llm_override:
-        temperature || modelVersion
-          ? {
-              temperature,
-              model_provider: modelProvider,
-              model_version: modelVersion,
-            }
-          : null,
-      use_existing_user_message: useExistingUserMessage,
-    }),
+    body,
+    signal,
  });
-  if (!sendMessageResponse.ok) {
-    const errorJson = await sendMessageResponse.json();
-    const errorMsg = errorJson.message || errorJson.detail || "";
-    throw Error(`Failed to send message - ${errorMsg}`);
-  }
+  if (!response.ok) {
+    throw new Error(`HTTP error! status: ${response.status}`);
+  }

-  yield* handleStream<PacketType>(sendMessageResponse);
+  yield* handleSSEStream<PacketType>(response);
}

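On the wire, the endpoint now declares `media_type="text/event-stream"` and the client switches from `handleStream` to `handleSSEStream`. For orientation, the sketch below shows conventional server-sent-events framing (`data: <json>` blocks separated by blank lines); note that the handler earlier in this branch actually yields bare JSON strings, so the exact framing `handleSSEStream` expects is an assumption here, not confirmed by the diff:

```python
import json

def sse_frame(packet: dict) -> str:
    # One SSE event: a "data:" line followed by a blank-line separator.
    return f"data: {json.dumps(packet)}\n\n"

def parse_sse(stream: str):
    # Inverse operation: split on blank lines and strip the "data: " prefix.
    for block in stream.split("\n\n"):
        if block.startswith("data: "):
            yield json.loads(block[len("data: "):])

frames = sse_frame({"answer_piece": "Hel"}) + sse_frame({"answer_piece": "lo"})
assert [p["answer_piece"] for p in parse_sse(frames)] == ["Hel", "lo"]
```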
export async function nameChatSession(chatSessionId: number, message: string) {

@@ -384,13 +389,12 @@ export function getLastSuccessfulMessageId(messageHistory: Message[]) {
    .reverse()
    .find(
      (message) =>
-        message.type === "assistant" &&
+        (message.type === "assistant" || message.type === "system") &&
        message.messageId !== -1 &&
        message.messageId !== null
    );
  return lastSuccessfulMessage ? lastSuccessfulMessage?.messageId : null;
}

export function processRawChatHistory(
  rawMessages: BackendMessage[]
): Map<number, Message> {

@@ -429,10 +433,11 @@ export function processRawChatHistory(
          citations: messageInfo?.citations || {},
        }
      : {}),
-    toolCalls: messageInfo.tool_calls,
+    toolCall: messageInfo.tool_call,
    parentMessageId: messageInfo.parent_message,
    childrenMessageIds: [],
    latestChildMessageId: messageInfo.latest_child_message,
+    alternate_model: messageInfo.alternate_model,
  };

  messages.set(messageInfo.message_id, message);

@@ -635,14 +640,14 @@ export async function uploadFilesForChat(
}

export async function useScrollonStream({
-  isStreaming,
+  chatState,
  scrollableDivRef,
  scrollDist,
  endDivRef,
  distance,
  debounce,
}: {
-  isStreaming: boolean;
+  chatState: ChatState;
  scrollableDivRef: RefObject<HTMLDivElement>;
  scrollDist: MutableRefObject<number>;
  endDivRef: RefObject<HTMLDivElement>;

@@ -656,7 +661,7 @@ export async function useScrollonStream({
  const previousScroll = useRef<number>(0);

  useEffect(() => {
-    if (isStreaming && scrollableDivRef && scrollableDivRef.current) {
+    if (chatState != "input" && scrollableDivRef && scrollableDivRef.current) {
      let newHeight: number = scrollableDivRef.current?.scrollTop!;
      const heightDifference = newHeight - previousScroll.current;
      previousScroll.current = newHeight;

@@ -712,7 +717,7 @@ export async function useScrollonStream({

  // scroll on end of stream if within distance
  useEffect(() => {
-    if (scrollableDivRef?.current && !isStreaming) {
+    if (scrollableDivRef?.current && chatState == "input") {
      if (scrollDist.current < distance - 50) {
        scrollableDivRef?.current?.scrollBy({
          left: 0,

@@ -721,5 +726,5 @@ export async function useScrollonStream({
        });
      }
    }
-  }, [isStreaming]);
+  }, [chatState]);
}

@@ -17,6 +17,7 @@ import {
  useState,
} from "react";
import ReactMarkdown from "react-markdown";
+import Prism from "prismjs";
import {
  DanswerDocument,
  FilteredDanswerDocument,

@@ -44,22 +45,29 @@ import "./custom-code-styles.css";
import { Persona } from "@/app/admin/assistants/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Citation } from "@/components/search/results/Citation";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";

import {
-  DislikeFeedbackIcon,
-  LikeFeedbackIcon,
+  ThumbsUpIcon,
+  ThumbsDownIcon,
+  LikeFeedback,
+  DislikeFeedback,
+  ToolCallIcon,
} from "@/components/icons/icons";
import {
  CustomTooltip,
  TooltipGroup,
} from "@/components/tooltip/CustomTooltip";
-import { ValidSources } from "@/lib/types";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useMouseTracking } from "./hooks";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { SettingsContext } from "@/components/settings/SettingsProvider";
+import DualPromptDisplay from "../tools/ImagePromptCitation";
+import { Popover } from "@/components/popover/Popover";
+import { PopupSpec } from "@/components/admin/connectors/Popup";
+import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
+import RegenerateOption from "../RegenerateOption";
+import { LlmOverride } from "@/lib/hooks";
+import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
+import { ValidSources } from "@/lib/types";

const TOOLS_WITH_CUSTOM_HANDLING = [
  SEARCH_TOOL_NAME,

@@ -110,9 +118,13 @@ function FileDisplay({
}

export const AIMessage = ({
+  regenerate,
+  alternateModel,
  shared,
  isActive,
  toggleDocumentSelection,
+  hasParentAI,
+  hasChildAI,
  alternativeAssistant,
  docs,
  messageId,

@@ -124,6 +136,7 @@ export const AIMessage = ({
  citedDocuments,
  toolCall,
  isComplete,
+  generatingTool,
  hasDocs,
  handleFeedback,
  isCurrentlyShowingRetrieved,

@@ -132,9 +145,16 @@ export const AIMessage = ({
  handleForceSearch,
  retrievalDisabled,
  currentPersona,
+  otherMessagesCanSwitchTo,
+  onMessageSelection,
+  setPopup,
}: {
  shared?: boolean;
+  hasChildAI?: boolean;
+  hasParentAI?: boolean;
  isActive?: boolean;
+  otherMessagesCanSwitchTo?: number[];
+  onMessageSelection?: (messageId: number) => void;
  selectedDocuments?: DanswerDocument[] | null;
  toggleDocumentSelection?: () => void;
  docs?: DanswerDocument[] | null;

@@ -148,6 +168,7 @@ export const AIMessage = ({
  citedDocuments?: [string, DanswerDocument][] | null;
  toolCall?: ToolCallMetadata;
  isComplete?: boolean;
+  generatingTool?: boolean;
  hasDocs?: boolean;
  handleFeedback?: (feedbackType: FeedbackType) => void;
  isCurrentlyShowingRetrieved?: boolean;

@@ -155,9 +176,13 @@ export const AIMessage = ({
  handleSearchQueryEdit?: (query: string) => void;
  handleForceSearch?: () => void;
  retrievalDisabled?: boolean;
+  alternateModel?: string;
+  regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
+  setPopup: (popupSpec: PopupSpec | null) => void;
}) => {
  const toolCallGenerating = toolCall && !toolCall.tool_result;
-  const processContent = (content: string | JSX.Element) => {
+  const buildFinalContentDisplay = (content: string | JSX.Element) => {
    if (typeof content !== "string") {
      return content;
    }

@@ -179,10 +204,41 @@ export const AIMessage = ({
      }
    }

-    return content + (!isComplete && !toolCallGenerating ? " [*]() " : "");
-  };
-  const finalContent = processContent(content as string);
+    const indicator = !isComplete && !toolCallGenerating ? " [*]()" : "";
+    const tool_citation =
+      isComplete && toolCall?.tool_result ? ` [[${toolCall.tool_name}]]()` : "";
+
+    return content + indicator + tool_citation;
+  };
+
+  const finalContent = buildFinalContentDisplay(content as string);
+  const citationRef = useRef<HTMLDivElement | null>(null);
+  const [isPopoverOpen, setIsPopoverOpen] = useState(false);
+
+  useEffect(() => {
+    const handleClickOutside = (event: MouseEvent) => {
+      if (
+        citationRef.current &&
+        !citationRef.current.contains(event.target as Node)
+      ) {
+        // setIsPopoverOpen(false);
+      }
+    };
+
+    document.addEventListener("mousedown", handleClickOutside);
+    return () => {
+      document.removeEventListener("mousedown", handleClickOutside);
+    };
+  }, []);
+
+  const [isReady, setIsReady] = useState(false);
+  useEffect(() => {
+    Prism.highlightAll();
+    setIsReady(true);
+  }, []);
+  // const finalContent = processContent(content as string);
+
+  const [isRegenerateHovered, setIsRegenerateHovered] = useState(false);
  const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking();

  const settings = useContext(SettingsContext);

@@ -240,59 +296,39 @@ export const AIMessage = ({
    });
  }

  const currentMessageInd = messageId
    ? otherMessagesCanSwitchTo?.indexOf(messageId)
    : undefined;
  const uniqueSources: ValidSources[] = Array.from(
    new Set((docs || []).map((doc) => doc.source_type))
  ).slice(0, 3);

  const includeMessageSwitcher =
    currentMessageInd !== undefined &&
    onMessageSelection &&
    otherMessagesCanSwitchTo &&
    otherMessagesCanSwitchTo.length > 1;

  return (
-    <div ref={trackedElementRef} className={"py-5 px-2 lg:px-5 relative flex "}>
+    <div
+      ref={trackedElementRef}
+      className={`${hasParentAI ? "pb-5" : "py-5"} px-2 lg:px-5 relative flex `}
+    >
      <div
        className={`mx-auto ${shared ? "w-full" : "w-[90%]"} max-w-message-max`}
      >
        <div className={`${!shared && "mobile:ml-4 xl:ml-8"}`}>
          <div className="flex">
-            <AssistantIcon
-              size="small"
-              assistant={alternativeAssistant || currentPersona}
-            />
+            {!hasParentAI ? (
+              <AssistantIcon
+                size="small"
+                assistant={alternativeAssistant || currentPersona}
+              />
+            ) : (
+              <div className="w-6" />
+            )}
-            <div className="w-full">
-              <div className="max-w-message-max break-words">
-                {(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) &&
-                  danswerSearchToolEnabledForPersona && (
-                    <>
-                      {query !== undefined &&
-                        handleShowRetrieved !== undefined &&
-                        isCurrentlyShowingRetrieved !== undefined &&
-                        !retrievalDisabled && (
-                          <div className="my-1">
-                            <SearchSummary
-                              query={query}
-                              hasDocs={hasDocs || false}
-                              messageId={messageId}
-                              finished={toolCall?.tool_result != undefined}
-                              isCurrentlyShowingRetrieved={
-                                isCurrentlyShowingRetrieved
-                              }
-                              handleShowRetrieved={handleShowRetrieved}
-                              handleSearchQueryEdit={handleSearchQueryEdit}
-                            />
-                          </div>
-                        )}
-                      {handleForceSearch &&
-                        content &&
-                        query === undefined &&
-                        !hasDocs &&
-                        !retrievalDisabled && (
-                          <div className="my-1">
-                            <SkippedSearch
-                              handleForceSearch={handleForceSearch}
-                            />
-                          </div>
-                        )}
-                    </>
-                  )}
-
+            <div className="w-full ml-4">
+              <div className="max-w-message-max break-words">
+                {(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (

@@ -303,19 +339,22 @@ export const AIMessage = ({
                    !retrievalDisabled && (
                      <div className="mb-1">
                        <SearchSummary
+                          docs={docs}
+                          filteredDocs={filteredDocs}
                          query={query}
-                          finished={toolCall?.tool_result != undefined}
                          hasDocs={hasDocs || false}
                          messageId={messageId}
-                          isCurrentlyShowingRetrieved={
-                            isCurrentlyShowingRetrieved
+                          finished={
+                            toolCall?.tool_result != undefined ||
+                            isComplete!
                          }
+                          toggleDocumentSelection={
+                            toggleDocumentSelection
+                          }
                          handleShowRetrieved={handleShowRetrieved}
                          handleSearchQueryEdit={handleSearchQueryEdit}
                        />
                      </div>
                    )}
                  {handleForceSearch &&
+                    !hasChildAI &&
                    content &&
                    query === undefined &&
                    !hasDocs &&

@@ -368,9 +407,8 @@ export const AIMessage = ({
                  {content || files ? (
                    <>
                      <FileDisplay files={files || []} />

                      {typeof content === "string" ? (
-                        <div className="overflow-x-visible w-full pr-2 max-w-[675px]">
+                        <div className="overflow-visible w-full pr-2 max-w-[675px]">
                          <ReactMarkdown
                            key={messageId}
                            className="prose max-w-full"

@@ -379,7 +417,49 @@ export const AIMessage = ({
                              const { node, ...rest } = props;
                              const value = rest.children;

-                              if (value?.toString().startsWith("*")) {
+                              if (
+                                value?.toString() == `[${SEARCH_TOOL_NAME}]`
+                              ) {
+                                return <></>;
+                              } else if (
+                                value?.toString() ==
+                                `[${IMAGE_GENERATION_TOOL_NAME}]`
+                              ) {
+                                return (
+                                  <Popover
+                                    open={isPopoverOpen}
+                                    onOpenChange={() => null} // only allow closing from the icon
+                                    content={
+                                      <button
+                                        onMouseDown={() => {
+                                          setIsPopoverOpen(!isPopoverOpen);
+                                        }}
+                                      >
+                                        <ToolCallIcon className="cursor-pointer flex-none text-blue-500 hover:text-blue-700 !h-4 !w-4 inline-block" />
+                                      </button>
+                                    }
+                                    popover={
+                                      <DualPromptDisplay
+                                        arg="Prompt"
+                                        // ref={citationRef}
+                                        setPopup={setPopup}
+                                        prompt1={
+                                          toolCall?.tool_result?.[0]
+                                            ?.revised_prompt
+                                        }
+                                        prompt2={
+                                          toolCall?.tool_result?.[1]
+                                            ?.revised_prompt
+                                        }
+                                      />
+                                    }
+                                    side="top"
+                                    align="center"
+                                  />
+                                );
+                              } else if (
+                                value?.toString().startsWith("*")
+                              ) {
                                return (
                                  <div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
                                );

@@ -402,7 +482,7 @@ export const AIMessage = ({
                              return (
                                <a
                                  key={node?.position?.start?.offset}
-                                  onMouseDown={() =>
+                                  onClick={() =>
                                    rest.href
                                      ? window.open(rest.href, "_blank")
                                      : undefined

@@ -416,7 +496,6 @@ export const AIMessage = ({
                              },
                              code: (props) => (
                                <CodeBlock
-                                  className="w-full"
                                  {...props}
                                  content={content as string}
                                />

@@ -440,82 +519,10 @@ export const AIMessage = ({
                      ) : isComplete ? null : (
                        <></>
                      )}
-                      {isComplete && docs && docs.length > 0 && (
-                        <div className="mt-2 -mx-8 w-full mb-4 flex relative">
-                          <div className="w-full">
-                            <div className="px-8 flex gap-x-2">
-                              {!settings?.isMobile &&
-                                filteredDocs.length > 0 &&
-                                filteredDocs.slice(0, 2).map((doc, ind) => (
-                                  <div
-                                    key={doc.document_id}
-                                    className={`w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 pb-2 pt-1 border-b
-                                    `}
-                                  >
-                                    <a
-                                      href={doc.link}
-                                      target="_blank"
-                                      className="text-sm flex w-full pt-1 gap-x-1.5 overflow-hidden justify-between font-semibold text-text-700"
-                                    >
-                                      <Citation link={doc.link} index={ind + 1} />
-                                      <p className="shrink truncate ellipsis break-all ">
-                                        {doc.semantic_identifier ||
-                                          doc.document_id}
-                                      </p>
-                                      <div className="ml-auto flex-none">
-                                        {doc.is_internet ? (
-                                          <InternetSearchIcon url={doc.link} />
-                                        ) : (
-                                          <SourceIcon
-                                            sourceType={doc.source_type}
-                                            iconSize={18}
-                                          />
-                                        )}
-                                      </div>
-                                    </a>
-                                    <div className="flex overscroll-x-scroll mt-.5">
-                                      <DocumentMetadataBlock document={doc} />
-                                    </div>
-                                    <div className="line-clamp-3 text-xs break-words pt-1">
-                                      {doc.blurb}
-                                    </div>
-                                  </div>
-                                ))}
-                              <div
-                                onClick={() => {
-                                  if (toggleDocumentSelection) {
-                                    toggleDocumentSelection();
-                                  }
-                                }}
-                                key={-1}
-                                className="cursor-pointer w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 py-2 border-b"
-                              >
-                                <div className="text-sm flex justify-between font-semibold text-text-700">
-                                  <p className="line-clamp-1">See context</p>
-                                  <div className="flex gap-x-1">
-                                    {uniqueSources.map((sourceType, ind) => {
-                                      return (
-                                        <div key={ind} className="flex-none">
-                                          <SourceIcon
-                                            sourceType={sourceType}
-                                            iconSize={18}
-                                          />
-                                        </div>
-                                      );
-                                    })}
-                                  </div>
-                                </div>
-                                <div className="line-clamp-3 text-xs break-words pt-1">
-                                  See more
-                                </div>
-                              </div>
-                            </div>
-                          </div>
-                        </div>
-                      )}
                    </div>

-                  {handleFeedback &&
+                  {!hasChildAI &&
+                    handleFeedback &&
                    isComplete &&
                    (isActive ? (
                      <div
                        className={`

@@ -525,21 +532,53 @@ export const AIMessage = ({
                        `}
                      >
                        <TooltipGroup>
+                          <div className="flex justify-start w-full gap-x-0.5">
+                            {includeMessageSwitcher && (
+                              <div className="-mx-1 mr-auto">
+                                <MessageSwitcher
+                                  currentPage={currentMessageInd + 1}
+                                  totalPages={otherMessagesCanSwitchTo.length}
+                                  handlePrevious={() => {
+                                    onMessageSelection(
+                                      otherMessagesCanSwitchTo[
+                                        currentMessageInd - 1
+                                      ]
+                                    );
+                                  }}
+                                  handleNext={() => {
+                                    onMessageSelection(
+                                      otherMessagesCanSwitchTo[
+                                        currentMessageInd + 1
+                                      ]
+                                    );
+                                  }}
+                                />
+                              </div>
+                            )}
+                          </div>
                          <CustomTooltip showTick line content="Copy!">
                            <CopyButton content={content.toString()} />
                          </CustomTooltip>
                          <CustomTooltip showTick line content="Good response!">
                            <HoverableIcon
-                              icon={<LikeFeedbackIcon />}
+                              icon={<LikeFeedback />}
                              onClick={() => handleFeedback("like")}
                            />
                          </CustomTooltip>
                          <CustomTooltip showTick line content="Bad response!">
                            <HoverableIcon
-                              icon={<DislikeFeedbackIcon />}
+                              icon={<DislikeFeedback size={16} />}
                              onClick={() => handleFeedback("dislike")}
                            />
                          </CustomTooltip>
+                          {regenerate && (
+                            <RegenerateOption
+                              onHoverChange={setIsRegenerateHovered}
+                              selectedAssistant={currentPersona!}
+                              regenerate={regenerate}
+                              alternateModel={alternateModel}
+                            />
+                          )}
                        </TooltipGroup>
                      </div>
                    ) : (

@@ -547,31 +586,63 @@ export const AIMessage = ({
                        ref={hoverElementRef}
                        className={`
                          absolute -bottom-4
-                          invisible ${(isHovering || settings?.isMobile) && "!visible"}
-                          opacity-0 ${(isHovering || settings?.isMobile) && "!opacity-100"}
+                          invisible ${(isHovering || isRegenerateHovered || settings?.isMobile) && "!visible"}
+                          opacity-0 ${(isHovering || isRegenerateHovered || settings?.isMobile) && "!opacity-100"}
                          translate-y-2 ${(isHovering || settings?.isMobile) && "!translate-y-0"}
                          transition-transform duration-300 ease-in-out
                          flex md:flex-row gap-x-0.5 bg-background-125/40 p-1.5 rounded-lg
                        `}
                      >
                        <TooltipGroup>
+                          <div className="flex justify-start w-full gap-x-0.5">
+                            {includeMessageSwitcher && (
+                              <div className="-mx-1 mr-auto">
+                                <MessageSwitcher
+                                  currentPage={currentMessageInd + 1}
+                                  totalPages={otherMessagesCanSwitchTo.length}
+                                  handlePrevious={() => {
+                                    onMessageSelection(
+                                      otherMessagesCanSwitchTo[
+                                        currentMessageInd - 1
+                                      ]
+                                    );
+                                  }}
+                                  handleNext={() => {
+                                    onMessageSelection(
+                                      otherMessagesCanSwitchTo[
+                                        currentMessageInd + 1
+                                      ]
+                                    );
+                                  }}
+                                />
+                              </div>
+                            )}
+                          </div>
                          <CustomTooltip showTick line content="Copy!">
                            <CopyButton content={content.toString()} />
                          </CustomTooltip>

                          <CustomTooltip showTick line content="Good response!">
                            <HoverableIcon
-                              icon={<LikeFeedbackIcon />}
+                              icon={<LikeFeedback />}
                              onClick={() => handleFeedback("like")}
                            />
                          </CustomTooltip>

                          <CustomTooltip showTick line content="Bad response!">
                            <HoverableIcon
-                              icon={<DislikeFeedbackIcon />}
+                              icon={<DislikeFeedback size={16} />}
                              onClick={() => handleFeedback("dislike")}
                            />
                          </CustomTooltip>
+                          {regenerate && (
+                            <RegenerateOption
+                              selectedAssistant={currentPersona!}
+                              regenerate={regenerate}
+                              alternateModel={alternateModel}
+                              onHoverChange={setIsRegenerateHovered}
+                            />
+                          )}
                        </TooltipGroup>
                      </div>
                    ))}

@@ -623,6 +694,7 @@ export const HumanMessage = ({
|
||||
onEdit,
|
||||
onMessageSelection,
|
||||
shared,
|
||||
stopGenerating = () => null,
|
||||
}: {
|
||||
shared?: boolean;
|
||||
content: string;
|
||||
@@ -631,6 +703,7 @@ export const HumanMessage = ({
|
||||
otherMessagesCanSwitchTo?: number[];
|
||||
onEdit?: (editedContent: string) => void;
|
||||
onMessageSelection?: (messageId: number) => void;
|
||||
stopGenerating?: () => void;
|
||||
}) => {
|
||||
const textareaRef = useRef<HTMLTextAreaElement>(null);
|
||||
|
||||
@@ -677,7 +750,6 @@ export const HumanMessage = ({
|
||||
<div className="xl:ml-8">
|
||||
<div className="flex flex-col mr-4">
|
||||
<FileDisplay alignBubble files={files || []} />
|
||||
|
||||
<div className="flex justify-end">
|
||||
<div className="w-full ml-8 flex w-full max-w-message-max break-words">
|
||||
{isEditing ? (
|
||||
@@ -700,24 +772,23 @@ export const HumanMessage = ({
|
||||
<textarea
|
||||
ref={textareaRef}
|
||||
className={`
|
||||
m-0
|
||||
w-full
|
||||
h-auto
|
||||
shrink
|
||||
border-0
|
||||
rounded-lg
|
||||
overflow-y-hidden
|
||||
bg-background-emphasis
|
||||
whitespace-normal
|
||||
break-word
|
||||
overscroll-contain
|
||||
outline-none
|
||||
placeholder-gray-400
|
||||
resize-none
|
||||
pl-4
|
||||
overflow-y-auto
|
||||
pr-12
|
||||
py-4`}
|
||||
m-0
|
||||
w-full
|
||||
h-auto
|
||||
shrink
|
||||
border-0
|
||||
rounded-lg
|
||||
bg-background-emphasis
|
||||
whitespace-normal
|
||||
break-word
|
||||
overscroll-contain
|
||||
outline-none
|
||||
placeholder-gray-400
|
||||
resize-none
|
||||
pl-4
|
||||
overflow-y-auto
|
||||
pr-12
|
||||
py-4`}
|
||||
aria-multiline
|
||||
role="textarea"
|
||||
value={editedContent}
|
||||
@@ -857,16 +928,18 @@ export const HumanMessage = ({
|
||||
<MessageSwitcher
|
||||
currentPage={currentMessageInd + 1}
|
||||
totalPages={otherMessagesCanSwitchTo.length}
|
||||
handlePrevious={() =>
|
||||
handlePrevious={() => {
|
||||
stopGenerating();
|
||||
onMessageSelection(
|
||||
otherMessagesCanSwitchTo[currentMessageInd - 1]
|
||||
)
|
||||
}
|
||||
handleNext={() =>
|
||||
);
|
||||
}}
|
||||
handleNext={() => {
|
||||
stopGenerating();
|
||||
onMessageSelection(
|
||||
otherMessagesCanSwitchTo[currentMessageInd + 1]
|
||||
)
|
||||
}
|
||||
);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
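The HumanMessage hunk above threads the new stopGenerating callback into the message switcher so that hopping between sibling messages halts any in-flight stream first. A minimal sketch of that pattern, in TypeScript; the helper name makeSwitchHandler is hypothetical, since the diff wires the two callbacks inline inside handlePrevious/handleNext:

// Sketch: stop the in-flight stream before rendering a sibling message.
// `makeSwitchHandler` is a hypothetical factory; the component above does
// the equivalent inline.
function makeSwitchHandler(
  onMessageSelection: (messageId: number) => void,
  stopGenerating: () => void
) {
  return (targetMessageId: number) => {
    stopGenerating(); // halt streaming so the old branch stops appending text
    onMessageSelection(targetMessageId); // then switch to the chosen branch
  };
}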
@@ -4,9 +4,21 @@ import {
} from "@/components/BasicClickable";
import { HoverPopup } from "@/components/HoverPopup";
import { Hoverable } from "@/components/Hoverable";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { SourceIcon } from "@/components/SourceIcon";
import { ChevronDownIcon, InfoIcon } from "@/components/icons/icons";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
import { Citation } from "@/components/search/results/Citation";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useEffect, useRef, useState } from "react";
import {
DanswerDocument,
FilteredDanswerDocument,
} from "@/lib/search/interfaces";
import { ValidSources } from "@/lib/types";
import { useContext, useEffect, useRef, useState } from "react";
import { FiCheck, FiEdit2, FiSearch, FiX } from "react-icons/fi";
import { DownChevron } from "react-select/dist/declarations/src/components/indicators";

export function ShowHideDocsButton({
messageId,
@@ -37,26 +49,35 @@ export function ShowHideDocsButton({

export function SearchSummary({
query,
hasDocs,
filteredDocs,
finished,
messageId,
isCurrentlyShowingRetrieved,
handleShowRetrieved,
docs,
toggleDocumentSelection,
handleSearchQueryEdit,
}: {
toggleDocumentSelection?: () => void;
docs?: DanswerDocument[] | null;
filteredDocs: FilteredDanswerDocument[];
finished: boolean;
query: string;
hasDocs: boolean;
messageId: number | null;
isCurrentlyShowingRetrieved: boolean;
handleShowRetrieved: (messageId: number | null) => void;
handleSearchQueryEdit?: (query: string) => void;
}) {
const [isEditing, setIsEditing] = useState(false);
const [finalQuery, setFinalQuery] = useState(query);
const [isOverflowed, setIsOverflowed] = useState(false);
const searchingForRef = useRef<HTMLDivElement>(null);
const editQueryRef = useRef<HTMLInputElement>(null);
const [isDropdownOpen, setIsDropdownOpen] = useState(false);
const searchingForRef = useRef<HTMLDivElement | null>(null);
const editQueryRef = useRef<HTMLInputElement | null>(null);

const settings = useContext(SettingsContext);

const uniqueSourceTypes = Array.from(
new Set((docs || []).map((doc) => doc.source_type))
).slice(0, 3);

const toggleDropdown = () => {
setIsDropdownOpen(!isDropdownOpen);
};

useEffect(() => {
const checkOverflow = () => {
@@ -70,7 +91,7 @@ export function SearchSummary({
};

checkOverflow();
window.addEventListener("resize", checkOverflow); // Recheck on window resize
window.addEventListener("resize", checkOverflow);

return () => window.removeEventListener("resize", checkOverflow);
}, []);
@@ -88,15 +109,30 @@ export function SearchSummary({
}, [query]);

const searchingForDisplay = (
<div className={`flex p-1 rounded ${isOverflowed && "cursor-default"}`}>
<FiSearch className="flex-none mr-2 my-auto" size={14} />
<div
className={`${!finished && "loading-text"}
!text-sm !line-clamp-1 !break-all px-0.5`}
ref={searchingForRef}
>
{finished ? "Searched" : "Searching"} for: <i> {finalQuery}</i>
</div>
<div
className={`flex my-auto items-center ${isOverflowed && "cursor-default"}`}
>
{finished ? (
<>
<div
onClick={() => {
toggleDropdown();
}}
className={` transition-colors duration-300 group-hover:text-text-toolhover cursor-pointer text-text-toolrun !line-clamp-1 !break-all pr-0.5`}
ref={searchingForRef}
>
Searched {filteredDocs.length > 0 && filteredDocs.length} document
{filteredDocs.length != 1 && "s"} for {query}
</div>
</>
) : (
<div
className={`loading-text !text-sm !line-clamp-1 !break-all px-0.5`}
ref={searchingForRef}
>
Searching for: <i> {finalQuery}</i>
</div>
)}
</div>
);

@@ -147,43 +183,127 @@ export function SearchSummary({
</div>
</div>
) : null;
const SearchBlock = ({ doc, ind }: { doc: DanswerDocument; ind: number }) => {
return (
<div
onClick={() => {
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={doc.document_id}
className={`flex items-start gap-3 px-4 py-3 text-token-text-secondary ${ind == 0 && "rounded-t-xl"} hover:bg-background-100 group relative text-sm`}
// className="w-full text-sm flex transition-all duration-500 hover:bg-background-125 bg-text-100 py-4 border-b"
>
<div className="mt-1 scale-[.9] flex-none">
{doc.is_internet ? (
<InternetSearchIcon url={doc.link} />
) : (
<SourceIcon sourceType={doc.source_type} iconSize={18} />
)}
</div>
<div className="flex flex-col">
<a
href={doc.link}
target="_blank"
className="line-clamp-1 text-text-900"
>
{/* <Citation link={doc.link} index={ind + 1} /> */}
<p className="shrink truncate ellipsis break-all ">
{doc.semantic_identifier || doc.document_id}
</p>
<p className="line-clamp-3 text-text-500 break-words">
{doc.blurb}
</p>
</a>
{/* <div className="flex overscroll-x-scroll mt-.5">
<DocumentMetadataBlock document={doc} />
</div> */}
</div>
</div>
);
};

return (
<div className="flex">
{isEditing ? (
editInput
) : (
<>
<div className="text-sm">
{isOverflowed ? (
<HoverPopup
mainContent={searchingForDisplay}
popupContent={
<div>
<b>Full query:</b>{" "}
<div className="mt-1 italic">{query}</div>
</div>
}
direction="top"
/>
) : (
searchingForDisplay
<>
<div className="flex gap-x-2 group">
{isEditing ? (
editInput
) : (
<>
<div className="my-auto text-sm">
{isOverflowed ? (
<HoverPopup
mainContent={searchingForDisplay}
popupContent={
<div>
<b>Full query:</b>{" "}
<div className="mt-1 italic">{query}</div>
</div>
}
direction="top"
/>
) : (
searchingForDisplay
)}
</div>
{handleSearchQueryEdit && (
<Tooltip delayDuration={1000} content={"Edit Search"}>
<button
className="my-auto cursor-pointer rounded"
onClick={() => {
setIsEditing(true);
}}
>
<FiEdit2 />
</button>
</Tooltip>
)}
</div>
{handleSearchQueryEdit && (
<Tooltip delayDuration={1000} content={"Edit Search"}>
<button
className="my-auto hover:bg-hover p-1.5 rounded"
<button
className="my-auto invisible group-hover:visible transition-all duration-300 hover:bg-hover rounded"
onClick={toggleDropdown}
>
<ChevronDownIcon
className={`transform transition-transform ${isDropdownOpen ? "rotate-180" : ""}`}
/>
</button>
</>
)}
</div>

{isDropdownOpen && docs && docs.length > 0 && (
<div
className={`mt-2 -mx-8 w-full mb-4 flex relative transition-all duration-300 ${isDropdownOpen ? "opacity-100 max-h-[1000px]" : "opacity-0 max-h-0"}`}
>
<div className="w-full">
<div className="mx-8 flex rounded overflow-hidden rounded-lg border-1.5 border divide-y divider-y-1.5 divider-y-border border-border flex-col gap-x-4">
{!settings?.isMobile &&
filteredDocs.length > 0 &&
filteredDocs.map((doc, ind) => (
<SearchBlock key={ind} doc={doc} ind={ind} />
))}

<div
onClick={() => {
setIsEditing(true);
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={-1}
className="cursor-pointer w-full flex transition-all duration-500 hover:bg-background-100 py-3 border-b"
>
<FiEdit2 />
</button>
</Tooltip>
)}
</>
<div key={0} className="px-3 invisible scale-[.9] flex-none">
<SourceIcon sourceType={"file"} iconSize={18} />
</div>
<div className="text-sm flex justify-between text-text-900">
<p className="line-clamp-1">See context</p>
<div className="flex gap-x-1"></div>
</div>
</div>
</div>
</div>
</div>
)}
</div>
</>
);
}
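The uniqueSourceTypes computation in SearchSummary is a small dedupe-and-cap idiom: collapse the documents' source types to a unique list and keep at most three for the compact icon row. A sketch with hypothetical sample data:

// Dedupe-and-cap, as in `uniqueSourceTypes` above (sample data is hypothetical).
const sampleDocs = [
  { source_type: "web" },
  { source_type: "slack" },
  { source_type: "web" },
  { source_type: "google_drive" },
  { source_type: "github" },
];
const uniqueSourceTypes = Array.from(
  new Set(sampleDocs.map((doc) => doc.source_type))
).slice(0, 3); // ["web", "slack", "google_drive"] — at most three icons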
@@ -27,15 +27,12 @@ export function SkippedSearch({
handleForceSearch: () => void;
}) {
return (
<div className="flex text-sm !pt-0 p-1">
<div className="flex group-hover:text-text-900 text-text-500 text-sm !pt-0 p-1">
<div className="flex mb-auto">
<FiBook className="my-auto flex-none mr-2" size={14} />
<div className="my-auto cursor-default">
<span className="mobile:hidden">
The AI decided this query didn't need a search
</span>
<span className="desktop:hidden">No search</span>
</div>
<span className="mobile:hidden">
The AI decided this query didn't need a search
</span>
<span className="desktop:hidden">No search</span>
</div>

<div className="ml-auto my-auto" onClick={handleForceSearch}>
@@ -26,6 +26,7 @@ import { CHAT_SESSION_ID_KEY, FOLDER_ID_KEY } from "@/lib/drag/constants";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { WarningCircle } from "@phosphor-icons/react";
import { CustomTooltip } from "@/components/tooltip/CustomTooltip";
import { NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH } from "@/lib/constants";

export function ChatSessionDisplay({
chatSession,
@@ -33,6 +34,7 @@ export function ChatSessionDisplay({
isSelected,
skipGradient,
closeSidebar,
stopGenerating = () => null,
showShareModal,
showDeleteModal,
}: {
@@ -43,6 +45,7 @@ export function ChatSessionDisplay({
// if not set, the gradient will still be applied and cause weirdness
skipGradient?: boolean;
closeSidebar?: () => void;
stopGenerating?: () => void;
showShareModal?: (chatSession: ChatSession) => void;
showDeleteModal?: (chatSession: ChatSession) => void;
}) {
@@ -99,6 +102,9 @@ export function ChatSessionDisplay({
className="flex my-1 group relative"
key={chatSession.id}
onClick={() => {
if (NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH) {
stopGenerating();
}
if (settings?.isMobile && closeSidebar) {
closeSidebar();
}
@@ -40,6 +40,7 @@ interface HistorySidebarProps {
reset?: () => void;
showShareModal?: (chatSession: ChatSession) => void;
showDeleteModal?: (chatSession: ChatSession) => void;
stopGenerating?: () => void;
}

export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
@@ -54,6 +55,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
openedFolders,
toggleSidebar,
removeToggle,
stopGenerating = () => null,
showShareModal,
showDeleteModal,
},
@@ -179,6 +181,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
)}
<div className="border-b border-border pb-4 mx-3" />
<PagesTab
stopGenerating={stopGenerating}
newFolderId={newFolderId}
showDeleteModal={showDeleteModal}
showShareModal={showShareModal}
@@ -17,10 +17,12 @@ export function PagesTab({
folders,
openedFolders,
closeSidebar,
stopGenerating,
newFolderId,
showShareModal,
showDeleteModal,
}: {
stopGenerating: () => void;
page: pageType;
existingChats?: ChatSession[];
currentChatId?: number;
@@ -125,6 +127,7 @@ export function PagesTab({
return (
<div key={`${chat.id}-${chat.name}`}>
<ChatSessionDisplay
stopGenerating={stopGenerating}
showDeleteModal={showDeleteModal}
showShareModal={showShareModal}
closeSidebar={closeSidebar}
@@ -15,6 +15,7 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import { useContext, useEffect, useState } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader";
import { usePopup } from "@/components/admin/connectors/Popup";

function BackToDanswerButton() {
const router = useRouter();
@@ -44,6 +45,7 @@ export function SharedChatDisplay({
Prism.highlightAll();
setIsReady(true);
}, []);
const { popup, setPopup } = usePopup();
if (!chatSession) {
return (
<div className="min-h-full w-full">
@@ -66,6 +68,7 @@ export function SharedChatDisplay({

return (
<div className="w-full h-[100dvh] overflow-hidden">
{popup}
<div className="flex max-h-full overflow-hidden pb-[72px]">
<div className="flex w-full overflow-hidden overflow-y-scroll">
<div className="w-full h-full flex-col flex max-w-message-max mx-auto">
@@ -95,6 +98,7 @@ export function SharedChatDisplay({
} else {
return (
<AIMessage
setPopup={setPopup}
shared
currentPersona={currentPersona!}
key={message.messageId}
97
web/src/app/chat/tools/ImageGeneratingAnimation.tsx
Normal file
@@ -0,0 +1,97 @@
import React, { useState, useEffect, useRef } from "react";

export default function GeneratingImage({ isCompleted = false }) {
  const [progress, setProgress] = useState(0);
  const progressRef = useRef(0);
  const animationRef = useRef<number>();
  const startTimeRef = useRef<number>(Date.now());

  useEffect(() => {
    let lastUpdateTime = 0;
    const updateInterval = 500; // Update at most every 500ms
    const animationDuration = 30000; // Animation will take 30 seconds to reach ~99%

    const animate = (timestamp: number) => {
      const elapsedTime = timestamp - startTimeRef.current;

      // Saturating exponential curve: climbs quickly, then slowly approaches the cap
      const maxProgress = 99.9;
      const progress =
        maxProgress * (1 - Math.exp(-elapsedTime / animationDuration));

      // Only update if enough time has passed since the last update
      if (timestamp - lastUpdateTime > updateInterval) {
        progressRef.current = progress;
        setProgress(Math.round(progress * 10) / 10); // Round to 1 decimal place
        lastUpdateTime = timestamp;
      }

      if (!isCompleted && elapsedTime < animationDuration) {
        animationRef.current = requestAnimationFrame(animate);
      }
    };

    startTimeRef.current = performance.now();
    animationRef.current = requestAnimationFrame(animate);

    return () => {
      if (animationRef.current) {
        cancelAnimationFrame(animationRef.current);
      }
    };
  }, [isCompleted]);

  useEffect(() => {
    if (isCompleted) {
      if (animationRef.current) {
        cancelAnimationFrame(animationRef.current);
      }
      setProgress(100);
    }
  }, [isCompleted]);

  return (
    <div className="object-cover object-center border border-neutral-200 bg-neutral-100 items-center justify-center overflow-hidden flex rounded-lg w-96 h-96 transition-opacity duration-300 opacity-100">
      <div className="m-auto relative flex">
        <svg className="w-16 h-16 transform -rotate-90" viewBox="0 0 100 100">
          <circle
            className="text-gray-200"
            strokeWidth="8"
            stroke="currentColor"
            fill="transparent"
            r="44"
            cx="50"
            cy="50"
          />
          <circle
            className="text-gray-800 transition-all duration-300"
            strokeWidth="8"
            strokeDasharray={276.46}
            strokeDashoffset={276.46 * (1 - progress / 100)}
            strokeLinecap="round"
            stroke="currentColor"
            fill="transparent"
            r="44"
            cx="50"
            cy="50"
          />
        </svg>
        <div className="absolute inset-0 flex items-center justify-center">
          <svg
            className="w-6 h-6 text-neutral-500 animate-pulse-strong"
            fill="none"
            viewBox="0 0 24 24"
            stroke="currentColor"
          >
            <path
              strokeLinecap="round"
              strokeLinejoin="round"
              strokeWidth="2"
              d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z"
            />
          </svg>
        </div>
      </div>
    </div>
  );
}
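The fill follows a saturating exponential, progress(t) = 99.9 · (1 − e^(−t/30000)) with t in milliseconds: roughly 63% after 30 s and 86% after 60 s, never quite reaching 100% until isCompleted snaps it there. The ring geometry also checks out: the strokeDasharray of 276.46 is the circumference 2π · 44 of the r=44 circle. A quick sketch sampling the curve:

// Sampling the progress curve used by GeneratingImage (t in milliseconds).
const progressAt = (elapsedMs: number): number =>
  99.9 * (1 - Math.exp(-elapsedMs / 30000));

console.log(progressAt(0).toFixed(1)); // "0.0"
console.log(progressAt(30000).toFixed(1)); // "63.1" (≈ 99.9 · (1 − 1/e))
console.log(progressAt(60000).toFixed(1)); // "86.4"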
83
web/src/app/chat/tools/ImagePromptCitation.tsx
Normal file
@@ -0,0 +1,83 @@
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { CopyIcon } from "@/components/icons/icons";
import { Divider } from "@tremor/react";
import React, { forwardRef, useState } from "react";
import { FiCheck } from "react-icons/fi";

interface PromptDisplayProps {
  prompt1: string;
  prompt2?: string;
  arg: string;
  setPopup: (popupSpec: PopupSpec | null) => void;
}

const DualPromptDisplay = forwardRef<HTMLDivElement, PromptDisplayProps>(
  ({ prompt1, prompt2, setPopup, arg }, ref) => {
    const [copied, setCopied] = useState<number | null>(null);

    const copyToClipboard = (text: string, index: number) => {
      navigator.clipboard
        .writeText(text)
        .then(() => {
          setPopup({ message: "Copied to clipboard", type: "success" });
          setCopied(index);
          setTimeout(() => setCopied(null), 2000); // Reset copy status after 2 seconds
        })
        .catch((err) => {
          setPopup({ message: "Failed to copy", type: "error" });
        });
    };

    const PromptSection = ({
      copied,
      prompt,
      index,
    }: {
      copied: number | null;
      prompt: string;
      index: number;
    }) => (
      <div className="w-full p-2 rounded-lg">
        <h2 className="text-lg font-semibold mb-2">
          {arg} {index + 1}
        </h2>

        <p className="line-clamp-6 text-sm text-gray-800">{prompt}</p>

        <button
          onMouseDown={() => copyToClipboard(prompt, index)}
          className="flex mt-2 text-sm cursor-pointer items-center justify-center py-2 px-3 border border-background-200 bg-inverted text-text-900 rounded-full hover:bg-background-100 transition duration-200"
        >
          {copied != null ? (
            <>
              <FiCheck className="mr-2" size={16} />
              Copied!
            </>
          ) : (
            <>
              <CopyIcon className="mr-2" size={16} />
              Copy
            </>
          )}
        </button>
      </div>
    );

    return (
      <div className="w-[400px] bg-inverted mx-auto p-6 rounded-lg shadow-lg">
        <div className="flex flex-col gap-x-4">
          <PromptSection copied={copied} prompt={prompt1} index={0} />
          {prompt2 && (
            <>
              <Divider />
              <PromptSection copied={copied} prompt={prompt2} index={1} />
            </>
          )}
        </div>
      </div>
    );
  }
);

DualPromptDisplay.displayName = "DualPromptDisplay";
export default DualPromptDisplay;
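A usage sketch for DualPromptDisplay, e.g. surfacing an original and a revised image prompt (cf. the revised_prompt field read in the AIMessage hunk); the prompt strings here are hypothetical, and setPopup comes from usePopup() as in SharedChatDisplay above. Note that copied is shared between the two sections, so copying either prompt flips both buttons to "Copied!" for two seconds.

// Hypothetical usage: show an original and a revised image-generation prompt.
<DualPromptDisplay
  arg="Prompt"
  prompt1="A watercolor fox in a snowy forest"
  prompt2="A detailed watercolor painting of a red fox in a snow-covered pine forest"
  setPopup={setPopup}
/>;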
0
web/src/app/chat/tools/SearchAnimation.tsx
Normal file

@@ -1 +1,6 @@
export type FeedbackType = "like" | "dislike";
export type ChatState = "input" | "loading" | "streaming" | "toolBuilding";
export interface RegenerationState {
regenerating: boolean;
finalMessageIndex: number;
}
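The widened ChatState union distinguishes tool assembly from plain streaming, which lets UI code dispatch exhaustively over the phase. A hypothetical helper along those lines:

// Hypothetical helper: map ChatState to what the submit button should offer.
function submitAction(state: ChatState): "send" | "stop" | "disabled" {
  switch (state) {
    case "input":
      return "send"; // idle, ready for a new user message
    case "loading":
    case "streaming":
      return "stop"; // an answer is in flight; offer stop-generating
    case "toolBuilding":
      return "disabled"; // a tool call is still assembling its result
  }
}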
@@ -119,6 +119,28 @@
}
}

@keyframes custom-spin {
0% {
transform: rotate(0deg);
}
25% {
transform: rotate(180deg);
}
50% {
transform: rotate(270deg);
}
75% {
transform: rotate(315deg);
}
100% {
transform: rotate(360deg);
}
}

.generating-spin {
animation: custom-spin 2s cubic-bezier(0.4, 0.2, 0.6, 0.8) infinite;
}

.collapsible {
max-height: 300px;
transition:
@@ -223,3 +245,17 @@ code[class*="language-"] {
.code-line .token.attr-name {
color: theme("colors.token-attr-name");
}

@keyframes pulse-strong {
0%,
100% {
opacity: 0.9;
}
50% {
opacity: 0.3;
}
}

.animate-pulse-strong {
animation: pulse-strong 1.5s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}
@@ -320,15 +320,15 @@ export const DefaultDropdown = forwardRef<HTMLDivElement, DefaultDropdownProps>(
const Content = (
<div
className={`
flex
text-sm
bg-background
px-3
py-1.5
rounded-lg
border
border-border
cursor-pointer`}
flex
text-sm
bg-background
px-3
py-1.5
rounded-lg
border
border-border
cursor-pointer`}
>
<p className="line-clamp-1">
{selectedOption?.name ||
@@ -1,4 +1,3 @@
import { IconProps } from "@tremor/react";
import { IconType } from "react-icons";

const ICON_SIZE = 15;
@@ -7,13 +6,22 @@ export const Hoverable: React.FC<{
icon: IconType;
onClick?: () => void;
size?: number;
}> = ({ icon, onClick, size = ICON_SIZE }) => {
active?: boolean;
hoverText?: string;
}> = ({ icon: Icon, active, hoverText, onClick, size = ICON_SIZE }) => {
return (
<div
className="hover:bg-hover p-1.5 rounded h-fit cursor-pointer"
className={`group relative flex items-center overflow-hidden p-1.5 h-fit rounded-md cursor-pointer transition-all duration-300 ease-in-out hover:bg-hover`}
onClick={onClick}
>
{icon({ size: size, className: "my-auto" })}
<div className="flex items-center ">
<Icon size={size} className="text-gray-600 shrink-0" />
{hoverText && (
<div className="max-w-0 leading-none whitespace-nowrap overflow-hidden transition-all duration-300 ease-in-out group-hover:max-w-xs group-hover:ml-2">
<span className="text-xs text-gray-700">{hoverText}</span>
</div>
)}
</div>
</div>
);
};
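With the new signature, Hoverable takes the icon component itself (rendered as <Icon />) plus an optional label that slides out on hover. A usage sketch; the onClick body is hypothetical:

// Usage sketch for the reworked Hoverable.
import { FiEdit2 } from "react-icons/fi";

<Hoverable
  icon={FiEdit2}
  hoverText="Edit"
  onClick={() => setIsEditing(true)} // hypothetical handler
/>;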
@@ -9,6 +9,7 @@ import { useSortable } from "@dnd-kit/sortable";
import React, { useState } from "react";
import { FiBookmark } from "react-icons/fi";
import { MdDragIndicator } from "react-icons/md";
import { DragHandle } from "../table/DragHandle";

export const AssistantCard = ({
assistant,
@@ -107,10 +108,7 @@ export function DraggableAssistantCard(props: {
style={style}
className="overlow-y-scroll inputscroll flex items-center"
>
<div {...attributes} {...listeners} className="mr-1 cursor-grab">
<MdDragIndicator className="h-3 w-3 flex-none" />
</div>

<DragHandle {...attributes} {...listeners} />
<AssistantCard {...props} />
</div>
);
@@ -710,6 +710,85 @@ export const ChevronIcon = ({
);
};

export const StarFeedback = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 24 24"
>
<path
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth="1.5"
d="m12.495 18.587l4.092 2.15a1.044 1.044 0 0 0 1.514-1.106l-.783-4.552a1.045 1.045 0 0 1 .303-.929l3.31-3.226a1.043 1.043 0 0 0-.575-1.785l-4.572-.657A1.044 1.044 0 0 1 15 7.907l-2.088-4.175a1.044 1.044 0 0 0-1.88 0L8.947 7.907a1.044 1.044 0 0 1-.783.575l-4.51.657a1.044 1.044 0 0 0-.584 1.785l3.309 3.226a1.044 1.044 0 0 1 .303.93l-.783 4.55a1.044 1.044 0 0 0 1.513 1.107l4.093-2.15a1.043 1.043 0 0 1 .991 0"
/>
</svg>
);
};

export const DislikeFeedback = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 24 24"
>
<g
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth="1.5"
>
<path d="M5.75 2.75H4.568c-.98 0-1.775.795-1.775 1.776v8.284c0 .98.795 1.775 1.775 1.775h1.184c.98 0 1.775-.794 1.775-1.775V4.526c0-.98-.795-1.776-1.775-1.776" />
<path d="m21.16 11.757l-1.42-7.101a2.368 2.368 0 0 0-2.367-1.906h-7.48a2.367 2.367 0 0 0-2.367 2.367v7.101a3.231 3.231 0 0 0 1.184 2.367l.982 5.918a.887.887 0 0 0 1.278.65l1.1-.543a3.551 3.551 0 0 0 1.87-4.048l-.496-1.965h5.396a2.368 2.368 0 0 0 2.32-2.84" />
</g>
</svg>
);
};

export const LikeFeedback = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 24 24"
>
<g
fill="none"
stroke="currentColor"
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth="1.5"
>
<path d="M5.75 9.415H4.568c-.98 0-1.775.794-1.775 1.775v8.284c0 .98.795 1.776 1.775 1.776h1.184c.98 0 1.775-.795 1.775-1.776V11.19c0-.98-.795-1.775-1.775-1.775" />
<path d="m21.16 12.243l-1.42 7.101a2.367 2.367 0 0 1-2.367 1.906h-7.48a2.367 2.367 0 0 1-2.367-2.367v-7.101A3.231 3.231 0 0 1 8.71 9.415l.982-5.918a.888.888 0 0 1 1.278-.65l1.1.544a3.55 3.55 0 0 1 1.87 4.047l-.496 1.965h5.396a2.367 2.367 0 0 1 2.32 2.84" />
</g>
</svg>
);
};

export const CheckmarkIcon = ({
size = 16,
className = defaultTailwindCSS,
@@ -1718,6 +1797,29 @@ export const FilledLikeIcon = ({
);
};

export const StopGeneratingIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 14 14"
>
<path
fill="currentColor"
fillRule="evenodd"
d="M1.5 0A1.5 1.5 0 0 0 0 1.5v11A1.5 1.5 0 0 0 1.5 14h11a1.5 1.5 0 0 0 1.5-1.5v-11A1.5 1.5 0 0 0 12.5 0z"
clipRule="evenodd"
/>
</svg>
);
};

export const LikeFeedbackIcon = ({
size = 16,
className = defaultTailwindCSS,
@@ -2581,3 +2683,40 @@ export const MinusIcon = ({
</svg>
);
};

export const ToolCallIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 19 15"
fill="none"
>
<path
d="M4.42 0.75H2.8625H2.75C1.64543 0.75 0.75 1.64543 0.75 2.75V11.65C0.75 12.7546 1.64543 13.65 2.75 13.65H2.8625C2.8625 13.65 2.8625 13.65 2.8625 13.65C2.8625 13.65 4.00751 13.65 4.42 13.65M13.98 13.65H15.5375H15.65C16.7546 13.65 17.65 12.7546 17.65 11.65V2.75C17.65 1.64543 16.7546 0.75 15.65 0.75H15.5375H13.98"
stroke="currentColor"
strokeWidth="1.5"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M5.55283 4.21963C5.25993 3.92674 4.78506 3.92674 4.49217 4.21963C4.19927 4.51252 4.19927 4.9874 4.49217 5.28029L6.36184 7.14996L4.49217 9.01963C4.19927 9.31252 4.19927 9.7874 4.49217 10.0803C4.78506 10.3732 5.25993 10.3732 5.55283 10.0803L7.95283 7.68029C8.24572 7.3874 8.24572 6.91252 7.95283 6.61963L5.55283 4.21963Z"
fill="currentColor"
stroke="currentColor"
strokeWidth="0.2"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M9.77753 8.75003C9.3357 8.75003 8.97753 9.10821 8.97753 9.55003C8.97753 9.99186 9.3357 10.35 9.77753 10.35H13.2775C13.7194 10.35 14.0775 9.99186 14.0775 9.55003C14.0775 9.10821 13.7194 8.75003 13.2775 8.75003H9.77753Z"
fill="currentColor"
stroke="currentColor"
strokeWidth="0.1"
/>
</svg>
);
};
@@ -23,31 +23,17 @@ export function Citation({
>
<a
onMouseDown={() => (link ? window.open(link, "_blank") : undefined)}
className="cursor-pointer inline ml-1 align-middle"
className="cursor-pointer inline ml-1 font-sans align-middle inline-block text-sm text-blue-500 cursor-help leading-none inline ml-1 align-middle"
>
<span className="group relative -top-1 text-sm text-gray-500 dark:text-gray-400 selection:bg-indigo-300 selection:text-black dark:selection:bg-indigo-900 dark:selection:text-white">
<span
className="inline-flex bg-background-200 group-hover:bg-background-300 items-center justify-center h-3.5 min-w-3.5 px-1 text-center text-xs rounded-full border-1 border-gray-400 ring-1 ring-gray-400 divide-gray-300 dark:divide-gray-700 dark:ring-gray-700 dark:border-gray-700 transition duration-150"
data-number="3"
>
{innerText}
</span>
</span>
[{innerText}]
</a>
</CustomTooltip>
);
} else {
return (
<CustomTooltip content={<div>This doc doesn't have a link!</div>}>
<div className="inline-block cursor-help leading-none inline ml-1 align-middle">
<span className="group relative -top-1 text-gray-500 dark:text-gray-400 selection:bg-indigo-300 selection:text-black dark:selection:bg-indigo-900 dark:selection:text-white">
<span
className="inline-flex bg-background-200 group-hover:bg-background-300 items-center justify-center h-3.5 min-w-3.5 flex-none px-1 text-center text-xs rounded-full border-1 border-gray-400 ring-1 ring-gray-400 divide-gray-300 dark:divide-gray-700 dark:ring-gray-700 dark:border-gray-700 transition duration-150"
data-number="3"
>
{innerText}
</span>
</span>
<div className="inline-block text-sm font-sans text-blue-500 cursor-help leading-none inline ml-1 align-middle">
[{innerText}]
</div>
</CustomTooltip>
);
@@ -4,9 +4,9 @@ import { MdDragIndicator } from "react-icons/md";
export const DragHandle = (props: any) => {
return (
<div
className={
props.isDragging ? "hover:cursor-grabbing" : "hover:cursor-grab"
}
className={`mobile:hidden
${props.isDragging ? "hover:cursor-grabbing" : "hover:cursor-grab"}
`}
{...props}
>
<MdDragIndicator />
@@ -44,8 +44,10 @@ export const CustomTooltip = ({
wrap,
showTick = false,
delay = 500,
maxWidth,
position = "bottom",
}: {
maxWidth?: boolean;
content: string | ReactNode;
children: JSX.Element;
large?: boolean;
@@ -56,6 +58,7 @@ export const CustomTooltip = ({
wrap?: boolean;
citation?: boolean;
position?: "top" | "bottom";
className?: string;
}) => {
const [isVisible, setIsVisible] = useState(false);
const [tooltipPosition, setTooltipPosition] = useState({ top: 0, left: 0 });
36
web/src/lib/chat/aiMessageSequence.ts
Normal file
@@ -0,0 +1,36 @@
// For handling AI message `sequences` (i.e. AI messages which are streamed in sequence as separate messages but are in reality one message)

import { Message } from "@/app/chat/interfaces";
import { DanswerDocument } from "../search/interfaces";

export function getConsecutiveAIMessagesAtEnd(
  messageHistory: Message[]
): Message[] {
  const aiMessages = [];
  for (let i = messageHistory.length - 1; i >= 0; i--) {
    if (messageHistory[i]?.type === "assistant") {
      aiMessages.unshift(messageHistory[i]);
    } else {
      break;
    }
  }
  return aiMessages;
}

export function getUniqueDocumentsFromAIMessages(
  messages: Message[]
): DanswerDocument[] {
  const uniqueDocumentsMap = new Map<string, DanswerDocument>();

  messages.forEach((message) => {
    if (message.documents) {
      message.documents.forEach((doc) => {
        const uniqueKey = `${doc.document_id}-${doc.chunk_ind}`;
        if (!uniqueDocumentsMap.has(uniqueKey)) {
          uniqueDocumentsMap.set(uniqueKey, doc);
        }
      });
    }
  });

  return Array.from(uniqueDocumentsMap.values());
}
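Taken together, the two helpers let the chat page treat a trailing run of streamed assistant messages as one logical reply. A sketch, assuming messageHistory is the page's Message[]:

// Sketch: collapse a streamed-in-parts assistant reply into one view.
const tail = getConsecutiveAIMessagesAtEnd(messageHistory);
const citedDocs = getUniqueDocumentsFromAIMessages(tail);
// `tail` holds every trailing "assistant" message, oldest first;
// `citedDocs` lists each cited document once, keyed on document_id + chunk_ind.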
@@ -4,6 +4,11 @@ export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000";
export const HEADER_HEIGHT = "h-16";
export const SUB_HEADER = "h-12";

export const SIDEBAR_WIDTH_CONST = 425;
export const MOBILE_SIDEBAR_WIDTH_CONST = 300;
export const SIDEBAR_CARD_WIDTH = `w-[400px]`;
export const MOBILE_SIDEBAR_CARD_WIDTH = `w-[275px]`;

export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080";
export const NEXT_PUBLIC_DISABLE_STREAMING =
process.env.NEXT_PUBLIC_DISABLE_STREAMING?.toLowerCase() === "true";
@@ -24,9 +29,6 @@ export const GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME =
export const SEARCH_TYPE_COOKIE_NAME = "search_type";
export const AGENTIC_SEARCH_TYPE_COOKIE_NAME = "agentic_type";

export const SIDEBAR_WIDTH_CONST = "350px";
export const SIDEBAR_WIDTH = `w-[350px]`;

export const LOGOUT_DISABLED =
process.env.NEXT_PUBLIC_DISABLE_LOGOUT?.toLowerCase() === "true";

@@ -35,6 +37,9 @@ export const NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN =

export const TOGGLED_CONNECTORS_COOKIE_NAME = "toggled_connectors";

export const NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH =
process.env.NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH?.toLowerCase() === "true";

/* Enterprise-only settings */

// NOTE: this should ONLY be used on the server-side. If used client side,
@@ -18,6 +18,9 @@ export type SearchType = (typeof SearchType)[keyof typeof SearchType];
export interface AnswerPiecePacket {
answer_piece: string;
}
export interface DelimiterPiece {
delimiter: boolean;
}

export interface ErrorMessagePacket {
error: string;
@@ -5,6 +5,7 @@ import {
import {
AnswerPiecePacket,
DanswerDocument,
DelimiterPiece,
DocumentInfoPacket,
ErrorMessagePacket,
Quote,
@@ -91,6 +92,7 @@ export const searchRequestStreamed = async ({
| DocumentInfoPacket
| LLMRelevanceFilterPacket
| BackendMessage
| DelimiterPiece
| RelevanceChunk
>(decoder.decode(value, { stream: true }), previousPartialChunk);
if (!completedChunks.length && !partialChunk) {
@@ -1,3 +1,5 @@
import { PacketType } from "@/app/chat/lib";

type NonEmptyObject = { [k: string]: any };

const processSingleChunk = <T extends NonEmptyObject>(
@@ -75,3 +77,33 @@ export async function* handleStream<T extends NonEmptyObject>(
yield await Promise.resolve(completedChunks);
}
}

export async function* handleSSEStream<T extends PacketType>(
streamingResponse: Response
): AsyncGenerator<T, void, unknown> {
const reader = streamingResponse.body?.getReader();
const decoder = new TextDecoder();

while (true) {
const rawChunk = await reader?.read();
if (!rawChunk) {
throw new Error("Unable to process chunk");
}
const { done, value } = rawChunk;
if (done) {
break;
}

const chunk = decoder.decode(value);
const lines = chunk.split("\n").filter((line) => line.trim() !== "");

for (const line of lines) {
try {
const data = JSON.parse(line) as T;
yield data;
} catch (error) {
console.error("Error parsing SSE data:", error);
}
}
}
}
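handleSSEStream is consumed with for await; a sketch, where the endpoint, the packet type argument, and the appendToAnswer callback are assumptions:

// Sketch: consuming handleSSEStream (endpoint and packet shape are assumptions).
const response = await fetch("/api/chat/send-message", { method: "POST" });
for await (const packet of handleSSEStream<PacketType>(response)) {
  if ("answer_piece" in packet) {
    appendToAnswer(packet.answer_piece); // hypothetical UI callback
  }
}

Note the generator assumes every JSON packet arrives whole on its own line; a packet split across network chunks would fail JSON.parse and be logged rather than recovered.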
@@ -90,10 +90,11 @@ module.exports = {
"background-200": "#e5e5e5", // neutral-200
"background-300": "#d4d4d4", // neutral-300
"background-400": "#a3a3a3", // neutral-400
"background-500": "#737373", // neutral-400
"background-600": "#525252", // neutral-400
"background-700": "#404040", // neutral-400
"background-800": "#262626", // neutral-800
"background-400": "#a3a3a3", // neutral-500
"background-500": "#737373", // darkMedium, neutral-500
"background-600": "#525252", // dark, neutral-600
"background-700": "#404040", // solid, neutral-700
"background-800": "#262626", // solidDark, neutral-800
"background-900": "#111827", // gray-900
"background-inverted": "#000000", // black

@@ -110,6 +111,10 @@ module.exports = {
"text-600": "#525252", // dark, neutral-600
"text-700": "#404040", // solid, neutral-700
"text-800": "#262626", // solidDark, neutral-800
"text-900": "#171717", // dark, neutral-900

"text-toolrun": "#a3a3a3", // darkMedium, neutral-500
"text-toolhover": "#171717", // dark, neutral-900

subtle: "#6b7280", // gray-500
default: "#4b5563", // gray-600