Compare commits

...

50 Commits

Author SHA1 Message Date
pablodanswer
4eb53ce56f rebase needs fixing 2024-08-19 07:40:53 -07:00
pablodanswer
2fc84ed63e post rebase fix 2024-08-18 16:41:12 -07:00
pablodanswer
722d5e6e54 add sequential tool calls 2024-08-18 16:40:07 -07:00
pablodanswer
14c30d2e4d add env variable 2024-08-18 15:05:44 -07:00
pablodanswer
6abad2fdd3 robust chat session state persistence 2024-08-18 15:05:44 -07:00
pablodanswer
4691e736f6 functional new message carry-over 2024-08-18 15:05:44 -07:00
pablodanswer
5a826a527f properly reset blank screen 2024-08-18 15:05:44 -07:00
pablodanswer
f92d31df70 refactored for stop / regenerate 2024-08-18 15:05:44 -07:00
pablodanswer
1eb786897a proper margin 2024-08-18 15:05:26 -07:00
pablodanswer
72471f9e1d remove parameter 2024-08-18 15:05:26 -07:00
pablodanswer
49c335d06a squash 2024-08-18 15:05:26 -07:00
pablodanswer
fda06b7739 more robust implementation for first messages 2024-08-18 15:05:26 -07:00
pablodanswer
00d44e31b3 validated + cleaner UI 2024-08-18 15:05:26 -07:00
pablodanswer
2a42c1dd18 functional once again post rebase but quite ugly 2024-08-18 15:05:26 -07:00
pablodanswer
05cd25043e add regenerate 2024-08-18 15:05:26 -07:00
pablodanswer
abebff50bb Enable seeding of analytics via file path (#2146)
* enable seeding of analytics via file path

* remove log
2024-08-18 15:05:26 -07:00
pablodanswer
0a7e672832 add handling for poorly formatting model names (#2143) 2024-08-18 15:05:26 -07:00
pablodanswer
221ab9134c add critical error just in case 2024-08-18 15:03:04 -07:00
pablodanswer
f7134202b6 slightly more specific logs 2024-08-18 14:44:10 -07:00
pablodanswer
bea11dc3aa include logs 2024-08-18 14:33:45 -07:00
pablodanswer
374b798071 update typing 2024-08-17 13:51:52 -07:00
pablodanswer
6a2e3edfcd add synchronous wrapper to avoid hampering main event loop 2024-08-17 13:39:22 -07:00
pablodanswer
2ef1731e32 tiny formatting (remove newline) 2024-08-17 09:29:39 -07:00
pablodanswer
7d4d7a5f5d clean final message handling 2024-08-17 01:14:31 -07:00
pablodanswer
ea2f9cf625 cleaner messages 2024-08-15 17:17:03 -07:00
pablodanswer
97dc9c5e31 add back stack trace detail 2024-08-15 16:46:32 -07:00
pablodanswer
249bcd46d9 clearer 2024-08-15 16:10:56 -07:00
pablodanswer
f29b727bc7 remove comments 2024-08-15 16:10:56 -07:00
pablodanswer
31fb6c0753 improve clarity + new SSE handling utility function 2024-08-15 16:10:56 -07:00
pablodanswer
a45e72c298 update utility + copy 2024-08-15 16:10:56 -07:00
pablodanswer
157548817c slightly more robust chat state 2024-08-15 16:10:56 -07:00
pablodanswer
d9396f77d1 remove false comment 2024-08-15 16:10:56 -07:00
pablodanswer
7bae6bbf8f remove log 2024-08-15 16:10:56 -07:00
pablodanswer
1d535769ed robustify 2024-08-15 16:10:56 -07:00
pablodanswer
8584a81fe2 unnecessary list removed 2024-08-15 16:10:56 -07:00
pablodanswer
5f4ac19928 robustify typing 2024-08-15 16:10:56 -07:00
pablodanswer
d898e4f738 remove logs 2024-08-15 16:10:56 -07:00
pablodanswer
19412f0aa0 add ChatState for more robust handling 2024-08-15 16:10:56 -07:00
pablodanswer
c338de30fd add new loading state to prevent collisions 2024-08-15 16:10:56 -07:00
pablodanswer
edfde621b9 formatting 2024-08-15 16:10:56 -07:00
pablodanswer
9306abf911 migrate to streaming response 2024-08-15 16:10:56 -07:00
pablodanswer
70d885b621 cleaner loop + data persistence 2024-08-15 16:10:56 -07:00
pablodanswer
53bea4f859 robustify frontend handling 2024-08-15 16:10:55 -07:00
pablodanswer
a79d734d96 typing 2024-08-15 16:10:28 -07:00
pablodanswer
25cd7de147 remove logs 2024-08-15 16:10:28 -07:00
pablodanswer
ab2916c807 robustify switching 2024-08-15 16:10:28 -07:00
pablodanswer
96112f1f95 functional rework of temporary user/assistant ID 2024-08-15 16:10:28 -07:00
pablodanswer
54502b32d3 remove logs 2024-08-15 16:10:28 -07:00
pablodanswer
9431e6c06c remove commits 2024-08-15 16:10:28 -07:00
pablodanswer
f18571d580 functional types + sidebar 2024-08-15 16:10:28 -07:00
64 changed files with 3024 additions and 756 deletions

View File

@@ -0,0 +1,59 @@
"""migrate tool calls
Revision ID: eb690a089310
Revises: ee3f4b47fad5
Create Date: 2024-08-04 17:07:47.533051
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "eb690a089310"
down_revision = "ee3f4b47fad5"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create the new column
op.add_column(
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
)
op.create_foreign_key(
"fk_chat_message_tool_call",
"chat_message",
"tool_call",
["tool_call_id"],
["id"],
)
# Migrate existing data
op.execute(
"UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
)
# Drop the old relationship
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
op.drop_column("tool_call", "message_id")
def downgrade() -> None:
# Add back the old column
op.add_column(
"tool_call",
sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
)
op.create_foreign_key(
"tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
)
# Migrate data back
op.execute(
"UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
)
# Drop the new column
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
op.drop_column("chat_message", "tool_call_id")

View File

@@ -0,0 +1,28 @@
"""Added alternate model to chat message
Revision ID: ee3f4b47fad5
Revises: 4a951134c801
Create Date: 2024-08-12 00:11:50.915845
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "ee3f4b47fad5"
down_revision = "4a951134c801"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_message",
sa.Column("alternate_model", sa.String(length=255), nullable=True),
)
def downgrade() -> None:
op.drop_column("chat_message", "alternate_model")

View File

@@ -36,9 +36,11 @@ def create_chat_chain(
chat_session_id: int,
db_session: Session,
prefetch_tool_calls: bool = True,
parent_id: int | None = None,
) -> tuple[ChatMessage, list[ChatMessage]]:
"""Build the linear chain of messages without including the root message"""
mainline_messages: list[ChatMessage] = []
all_chat_messages = get_chat_messages_by_session(
chat_session_id=chat_session_id,
user_id=None,
@@ -60,7 +62,7 @@ def create_chat_chain(
current_message: ChatMessage | None = root_message
while current_message is not None:
child_msg = current_message.latest_child_message
if not child_msg:
if not child_msg or (parent_id and current_message.id == parent_id):
break
current_message = id_to_msg.get(child_msg)

View File

@@ -64,6 +64,10 @@ class DocumentRelevance(BaseModel):
relevance_summaries: dict[str, RelevanceAnalysis]
class Delimiter(BaseModel):
delimiter: bool
class DanswerAnswerPiece(BaseModel):
# A small piece of a complete answer. Used for streaming back answers.
answer_piece: str | None # if None, specifies the end of an Answer
@@ -76,6 +80,11 @@ class CitationInfo(BaseModel):
document_id: str
class MessageResponseIDInfo(BaseModel):
user_message_id: int | None
reserved_assistant_message_id: int
class StreamingError(BaseModel):
error: str
stack_trace: str | None = None
@@ -137,6 +146,7 @@ AnswerQuestionPossibleReturn = (
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| Delimiter
)

View File

@@ -9,8 +9,10 @@ from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import Delimiter
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import BING_API_KEY
@@ -27,6 +29,7 @@ from danswer.db.chat import get_chat_session_by_id
from danswer.db.chat import get_db_search_doc_by_id
from danswer.db.chat import get_doc_query_identifiers_from_model
from danswer.db.chat import get_or_create_root_message
from danswer.db.chat import reserve_message_id
from danswer.db.chat import translate_db_message_to_chat_message_detail
from danswer.db.chat import translate_db_search_doc_to_server_search_doc
from danswer.db.embedding_model import get_current_db_embedding_model
@@ -88,7 +91,7 @@ from danswer.tools.search.search_tool import SearchTool
from danswer.tools.search.search_tool import SECTION_RELEVANCE_LIST_ID
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallMetadata
from danswer.tools.utils import compute_all_tool_tokens
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
@@ -241,6 +244,8 @@ ChatPacket = (
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| MessageResponseIDInfo
| Delimiter
)
ChatPacketStream = Iterator[ChatPacket]
@@ -256,9 +261,9 @@ def stream_chat_message_objects(
max_document_percentage: float = CHAT_TARGET_CHUNK_PERCENTAGE,
# if specified, uses the last user message and does not create a new user message based
# on the `new_msg_req.message`. Currently, requires a state where the last message is a
# user message (e.g. this can only be used for the chat-seeding flow).
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> ChatPacketStream:
"""Streams in order:
1. [conditional] Retrieved documents if a search needs to be run
@@ -342,7 +347,15 @@ def stream_chat_message_objects(
parent_message = root_message
user_message = None
if not use_existing_user_message:
if new_msg_req.regenerate:
final_msg, history_msgs = create_chat_chain(
parent_id=parent_id,
chat_session_id=chat_session_id,
db_session=db_session,
)
elif not use_existing_user_message:
# Create new message at the right place in the tree and update the parent's child pointer
# Don't commit yet until we verify the chat message chain
user_message = create_new_chat_message(
@@ -451,13 +464,29 @@ def stream_chat_message_objects(
use_sections=new_msg_req.chunks_above > 0
or new_msg_req.chunks_below > 0,
)
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=user_message.id
if user_message is not None
else parent_message.id,
message_type=MessageType.ASSISTANT,
)
yield MessageResponseIDInfo(
user_message_id=user_message.id if user_message else None,
reserved_assistant_message_id=reserved_message_id,
)
alternate_model = (
new_msg_req.llm_override.model_version if new_msg_req.llm_override else None
)
# Cannot determine these without the LLM step or breaking out early
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=final_msg,
prompt_id=prompt_id,
alternate_model=alternate_model,
# message=,
# rephrased_query=,
# token_count=,
@@ -581,9 +610,11 @@ def stream_chat_message_objects(
document_pruning_config.using_tool_message = explicit_tool_calling_supported(
llm_provider, llm_model_name
)
tool_has_been_called = False # TODO remove
# LLM prompt building, response capturing, etc.
answer = Answer(
is_connected=is_connected,
question=final_msg.message,
latest_query_files=latest_query_files,
answer_style_config=AnswerStyleConfig(
@@ -617,8 +648,11 @@ def stream_chat_message_objects(
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
tool_has_been_called = True
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
@@ -689,9 +723,76 @@ def stream_chat_message_objects(
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
if isinstance(packet, Delimiter):
db_citations = None
if reference_db_search_docs:
db_citations = translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
if tool_result is None:
tool_call = None
else:
tool_call = ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
gen_ai_response_message = partial_response(
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query
if qa_docs_response
else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_call=tool_call,
)
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
yield Delimiter(delimiter=True)
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message,
prompt_id=prompt_id,
# message=,
# rephrased_query=,
# token_count=,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
# error=,
# reference_docs=,
db_session=db_session,
commit=False,
)
else:
if isinstance(packet, ToolCallMetadata):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
error_msg = str(e)
logger.exception(f"Failed to process chat message: {error_msg}")
@@ -703,54 +804,56 @@ def stream_chat_message_objects(
db_session.rollback()
return
# Post-LLM answer processing
try:
db_citations = None
if reference_db_search_docs:
db_citations = translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
if not tool_has_been_called:
try:
db_citations = None
if reference_db_search_docs:
db_citations = translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
gen_ai_response_message = partial_response(
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_calls=[
ToolCall(
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=(
qa_docs_response.rephrased_query if qa_docs_response else None
),
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=db_citations,
error=None,
tool_call=ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
if tool_result
else [],
)
db_session.commit() # actually save user / assistant message
if tool_result
else None,
)
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
logger.debug("Committing messages")
db_session.commit() # actually save user / assistant message
yield msg_detail_response
except Exception as e:
logger.exception(e)
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
yield msg_detail_response
except Exception as e:
error_msg = str(e)
logger.exception(error_msg)
# Frontend will erase whatever answer and show this instead
yield StreamingError(error="Failed to parse LLM output")
@log_generator_function_time()
@@ -759,6 +862,7 @@ def stream_chat_message(
user: User | None,
use_existing_user_message: bool = False,
litellm_additional_headers: dict[str, str] | None = None,
is_connected: Callable[[], bool] | None = None,
) -> Iterator[str]:
with get_session_context_manager() as db_session:
objects = stream_chat_message_objects(
@@ -767,6 +871,7 @@ def stream_chat_message(
db_session=db_session,
use_existing_user_message=use_existing_user_message,
litellm_additional_headers=litellm_additional_headers,
is_connected=is_connected,
)
for obj in objects:
yield get_json_line(obj.dict())

View File

@@ -42,7 +42,8 @@ prompts:
task: >
Generate an image based on the user's description.
Provide a detailed description of the generated image, including key elements, colors, and composition.
Provide a detailed description of the generated image, including key elements, colors, and composition.
If the request is not possible or appropriate, explain why and suggest alternatives.
datetime_aware: true

View File

@@ -13,6 +13,9 @@ DEFAULT_BOOST = 0
SESSION_KEY = "session"
# For tool calling
MAXIMUM_TOOL_CALL_SEQUENCE = 5
# For chunking/processing chunks
RETURN_SEPARATOR = "\n\r\n"
SECTION_SEPARATOR = "\n\n"

View File

@@ -143,7 +143,7 @@ def handle_standard_answers(
parent_message=root_message,
prompt_id=prompt.id if prompt else None,
message=query_msg.message,
token_count=0,
token_count=10,
message_type=MessageType.USER,
db_session=db_session,
commit=True,

View File

@@ -36,7 +36,7 @@ from danswer.search.models import RetrievalDocs
from danswer.search.models import SavedSearchDoc
from danswer.search.models import SearchDoc as ServerSearchDoc
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallMetadata
from danswer.utils.logger import setup_logger
@@ -147,8 +147,14 @@ def delete_search_doc_message_relationship(
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
db_session.execute(stmt)
chat_message = (
db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
)
if chat_message and chat_message.tool_call_id:
stmt = delete(ToolCall).where(ToolCall.id == chat_message.tool_call_id)
db_session.execute(stmt)
chat_message.tool_call_id = None
db_session.commit()
@@ -350,7 +356,7 @@ def get_chat_messages_by_session(
)
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
stmt = stmt.options(joinedload(ChatMessage.tool_call))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -393,6 +399,34 @@ def get_or_create_root_message(
return new_root_message
def reserve_message_id(
db_session: Session,
chat_session_id: int,
parent_message: int,
message_type: MessageType,
) -> int:
# Create an empty chat message
empty_message = ChatMessage(
chat_session_id=chat_session_id,
parent_message=parent_message,
latest_child_message=None,
message="",
token_count=0,
message_type=message_type,
)
# Add the empty message to the session
db_session.add(empty_message)
# Flush the session to get an ID for the new chat message
db_session.flush()
# Get the ID of the newly created message
new_id = empty_message.id
return new_id
def create_new_chat_message(
chat_session_id: int,
parent_message: ChatMessage,
@@ -408,38 +442,62 @@ def create_new_chat_message(
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_calls: list[ToolCall] | None = None,
tool_call: ToolCall | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
alternate_model: str | None = None,
) -> ChatMessage:
new_chat_message = ChatMessage(
chat_session_id=chat_session_id,
parent_message=parent_message.id,
latest_child_message=None,
message=message,
rephrased_query=rephrased_query,
prompt_id=prompt_id,
token_count=token_count,
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
error=error,
alternate_assistant_id=alternate_assistant_id,
)
if reserved_message_id is not None:
# Edit existing message
existing_message = db_session.query(ChatMessage).get(reserved_message_id)
if existing_message is None:
raise ValueError(f"No message found with id {reserved_message_id}")
existing_message.chat_session_id = chat_session_id
existing_message.parent_message = parent_message.id
existing_message.message = message
existing_message.rephrased_query = rephrased_query
existing_message.prompt_id = prompt_id
existing_message.token_count = token_count
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_call = tool_call
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.alternate_model = alternate_model
new_chat_message = existing_message
else:
# Create new message
new_chat_message = ChatMessage(
chat_session_id=chat_session_id,
parent_message=parent_message.id,
latest_child_message=None,
message=message,
rephrased_query=rephrased_query,
prompt_id=prompt_id,
token_count=token_count,
message_type=message_type,
citations=citations,
files=files,
tool_call=tool_call,
error=error,
alternate_assistant_id=alternate_assistant_id,
alternate_model=alternate_model,
)
db_session.add(new_chat_message)
# SQL Alchemy will propagate this to update the reference_docs' foreign keys
if reference_docs:
new_chat_message.search_docs = reference_docs
db_session.add(new_chat_message)
# Flush the session to get an ID for the new chat message
db_session.flush()
parent_message.latest_child_message = new_chat_message.id
if commit:
db_session.commit()
return new_chat_message
@@ -656,15 +714,15 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
tool_call=ToolCallMetadata(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
alternate_assistant_id=chat_message.alternate_assistant_id,
alternate_model=chat_message.alternate_model,
)
return chat_msg_detail

View File

@@ -764,10 +764,11 @@ class ToolCall(Base):
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
message: Mapped["ChatMessage"] = relationship(
"ChatMessage", back_populates="tool_calls"
"ChatMessage",
back_populates="tool_call",
uselist=False,
foreign_keys="ChatMessage.tool_call_id",
)
@@ -848,9 +849,13 @@ class ChatMessage(Base):
Integer, ForeignKey("persona.id"), nullable=True
)
alternate_model: Mapped[str | None] = mapped_column(String, nullable=True)
parent_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
latest_child_message: Mapped[int | None] = mapped_column(Integer, nullable=True)
message: Mapped[str] = mapped_column(Text)
tool_call_id: Mapped[int | None] = mapped_column(
ForeignKey("tool_call.id"), nullable=True
)
rephrased_query: Mapped[str] = mapped_column(Text, nullable=True)
# If None, then there is no answer generation, it's the special case of only
# showing the user the retrieved docs
@@ -893,9 +898,8 @@ class ChatMessage(Base):
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
tool_call: Mapped["ToolCall"] = relationship(
"ToolCall", back_populates="message", foreign_keys=[tool_call_id]
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",

View File

@@ -1,4 +1,6 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from uuid import uuid4
@@ -6,12 +8,19 @@ from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import HumanMessage
from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DanswerQuotes
from danswer.chat.models import Delimiter
from danswer.chat.models import ImageGenerationDisplay
from danswer.chat.models import LlmDoc
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.constants import MessageType
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
@@ -51,20 +60,21 @@ from danswer.tools.message import ToolCallSummary
from danswer.tools.search.search_tool import FINAL_CONTEXT_DOCUMENTS
from danswer.tools.search.search_tool import SEARCH_DOC_CONTENT_ID
from danswer.tools.search.search_tool import SEARCH_RESPONSE_SUMMARY_ID
from danswer.tools.search.search_tool import SearchResponseSummary
from danswer.tools.search.search_tool import SearchTool
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.tools.tool_runner import (
check_which_tools_should_run_for_non_tool_calling_llm,
)
from danswer.tools.tool_runner import ToolCallFinalResult
from danswer.tools.tool_runner import ToolCallKickoff
from danswer.tools.tool_runner import ToolCallMetadata
from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
# DanswerQuotes | CitationInfo | DanswerContexts | ImageGenerationDisplay | CustomToolResponse | StreamingError
logger = setup_logger()
@@ -115,6 +125,8 @@ class Answer:
# Returns the full document sections text from the search tool
return_contexts: bool = False,
skip_gen_ai_answer_generation: bool = False,
is_connected: Callable[[], bool] | None = None,
tool_uses: dict | None = None,
) -> None:
if single_message_history and message_history:
raise ValueError(
@@ -122,6 +134,7 @@ class Answer:
)
self.question = question
self.is_connected: Callable[[], bool] | None = is_connected
self.latest_query_files = latest_query_files or []
self.file_id_to_file = {file.file_id: file for file in (files or [])}
@@ -132,12 +145,17 @@ class Answer:
self.skip_explicit_tool_calling = skip_explicit_tool_calling
self.message_history = message_history or []
self.tool_uses = tool_uses or {}
# used for QA flow where we only want to send a single message
self.single_message_history = single_message_history
self.answer_style_config = answer_style_config
self.prompt_config = prompt_config
self.current_streamed_output: list = []
self.processing_stream: list = []
self.final_context_docs: list = []
self.llm = llm
self.llm_tokenizer = get_tokenizer(
provider_type=llm.config.model_provider,
@@ -153,6 +171,7 @@ class Answer:
self._return_contexts = return_contexts
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
@@ -174,6 +193,7 @@ class Answer:
),
)
)
elif self.answer_style_config.quotes_config:
prompt_builder.update_user_prompt(
build_quotes_user_message(
@@ -184,11 +204,219 @@ class Answer:
)
)
def _raw_output_for_explicit_tool_calling_llms_loop(
self,
) -> Iterator[
str | ToolCallKickoff | ToolResponse | ToolCallMetadata | Delimiter | Any
]:
count = 1
maximum_count = 4
while count <= maximum_count:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
count += 1
print(f"COUNT IS {count}")
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
tool_call_chunk = AIMessageChunk(content="")
tool_call_chunk.tool_calls = [
{
"name": self.force_use_tool.tool_name,
"args": self.force_use_tool.args,
"id": str(uuid4()),
}
]
else:
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
print(f"-----------------------\nThe current prompt is: {prompt}")
final_tool_definitions = [
tool.tool_definition()
for tool in filter_tools_for_force_tool_use(
self.tools, self.force_use_tool
)
]
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
tool_choice="required" if self.force_use_tool.force_use else None,
):
if isinstance(message, AIMessageChunk) and (
message.tool_call_chunks or message.tool_calls
):
if tool_call_chunk is None:
tool_call_chunk = message
else:
tool_call_chunk += message # type: ignore
else:
if tool_call_chunk is None and count != 2:
print("Skipping the tool call + message compeltely")
return
elif message.content:
yield cast(str, message.content)
if not tool_call_chunk:
print(
"Skipping the tool call but generated message due to lack of existing tool call messages"
)
return
tool_call_requests = tool_call_chunk.tool_calls
logger.critical(
f"-------------------TOOL CALL REQUESTS ({len(tool_call_requests)})-------------------"
)
for tool_call_request in tool_call_requests:
known_tools_by_name = [
tool
for tool in self.tools
if tool.name == tool_call_request["name"]
]
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
print("my tool call request is htis")
print(tool_args)
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
tool_responses = list(tool_runner.tool_responses())
yield from tool_responses
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
tool_call_result=build_tool_message(
tool_call_request, tool_runner.tool_message_content()
),
)
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question, img_urls=img_urls
)
)
yield tool_runner.tool_final_result()
# Update message history with tool call and response
self.message_history.append(
PreviousMessage(
message=str(tool_call_request),
message_type=MessageType.ASSISTANT,
token_count=10, # You may want to implement a token counting method
tool_call=None,
files=[],
)
)
self.message_history.append(
PreviousMessage(
message="\n".join(str(response) for response in tool_responses),
message_type=MessageType.SYSTEM,
token_count=10,
tool_call=None,
files=[],
)
)
# Generate response based on updated message history
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
print("-------------")
print("NEW PROMPT")
print(prompt)
print("\n\n\n\n\n\n-------\n\n------")
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=self.final_context_docs or [],
doc_id_to_rank_map=map_document_id_order(
self.final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
response_stream = process_answer_stream_fn(
message_generator_to_string_generator(
self.llm.stream(prompt=prompt)
)
)
response_content = ""
for chunk in response_stream:
response_content += (
chunk.answer_piece
if hasattr(chunk, "answer_piece")
and chunk.answer_piece is not None
else str(chunk)
)
yield chunk
yield "FINAL TOKEN"
# Update message history with LLM response
self.message_history.append(
PreviousMessage(
message=response_content,
message_type=MessageType.ASSISTANT,
token_count=10,
tool_call=None,
files=[], # You may want to implement a token counting method
)
)
def _raw_output_for_explicit_tool_calling_llms(
self,
) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
) -> Iterator[
str
| ToolCallKickoff
| ToolResponse
| ToolCallMetadata
| AnswerQuestionPossibleReturn
| str
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
# special things we need to keep track of for the SearchTool
# search_results: list[
# LlmDoc
# ] | None = None # raw results that will be displayed to the user
# final_context_docs: list[
# LlmDoc
# ] | None = None # processed docs to feed into the LLM
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
@@ -235,6 +463,8 @@ class Answer:
tool_call_chunk += message # type: ignore
else:
if message.content:
if self.is_cancelled:
return
yield cast(str, message.content)
if not tool_call_chunk:
@@ -290,20 +520,49 @@ class Answer:
)
)
yield tool_runner.tool_final_result()
# Streaming response
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
yield from message_generator_to_string_generator(
for token in message_generator_to_string_generator(
self.llm.stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
)
):
if self.is_cancelled:
return
yield token
return
# ADD BACK IN
# process_answer_stream_fn = _get_answer_stream_processor(
# context_docs=[],
# # if doc selection is enabled, then search_results will be None,
# # so we need to use the final_context_docs
# doc_id_to_rank_map=map_document_id_order([]),
# answer_style_configs=self.answer_style_config,
# )
# yield from process_answer_stream_fn(
# message_generator_to_string_generator(self.llm.stream(prompt=prompt))
# )
def _raw_output_for_non_explicit_tool_calling_llms(
self,
) -> Iterator[str | ToolCallKickoff | ToolResponse | ToolCallFinalResult]:
) -> Iterator[
str
| ToolCallKickoff
| ToolResponse
| ToolCallMetadata
| AnswerQuestionStreamReturn
| DanswerAnswerPiece
| DanswerQuotes
| CitationInfo
| DanswerContexts
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| Delimiter
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
@@ -378,9 +637,13 @@ class Answer:
)
)
prompt = prompt_builder.build()
yield from message_generator_to_string_generator(
for token in message_generator_to_string_generator(
self.llm.stream(prompt=prompt)
)
):
if self.is_cancelled:
return
yield token
return
tool, tool_args = chosen_tool_and_args
@@ -430,11 +693,28 @@ class Answer:
)
)
final = tool_runner.tool_final_result()
yield final
# Streaming response
prompt = prompt_builder.build()
yield from message_generator_to_string_generator(self.llm.stream(prompt=prompt))
for token in message_generator_to_string_generator(
self.llm.stream(prompt=prompt)
):
if self.is_cancelled:
return
yield token
# Add this back in
# process_answer_stream_fn = _get_answer_stream_processor(
# context_docs=final_context_documents or [],
# # if doc selection is enabled, then search_results will be None,
# # so we need to use the final_context_docs
# doc_id_to_rank_map=map_document_id_order(final_context_documents or []),
# answer_style_configs=self.answer_style_config,
# )
# yield from process_answer_stream_fn(
# message_generator_to_string_generator(self.llm.stream(prompt=prompt))
# )
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -443,97 +723,102 @@ class Answer:
return
output_generator = (
self._raw_output_for_explicit_tool_calling_llms()
self._raw_output_for_explicit_tool_calling_llms_loop()
if explicit_tool_calling_supported(
self.llm.config.model_provider, self.llm.config.model_name
)
and not self.skip_explicit_tool_calling
else self._raw_output_for_non_explicit_tool_calling_llms()
)
print(f" output generator {output_generator} ")
self.processing_stream = []
def _process_stream(
stream: Iterator[ToolCallKickoff | ToolResponse | str],
stream: Iterator[
str
| ToolCallKickoff
| ToolResponse
| ToolCallMetadata
| Delimiter
| Any
],
) -> AnswerStream:
message = None
# special things we need to keep track of for the SearchTool
search_results: list[
LlmDoc
] | None = None # raw results that will be displayed to the user
final_context_docs: list[
LlmDoc
] | None = None # processed docs to feed into the LLM
for message in stream:
if isinstance(message, ToolCallKickoff) or isinstance(
message, ToolCallFinalResult
message, ToolCallMetadata
):
yield message
elif isinstance(message, ToolResponse):
if message.id == SEARCH_RESPONSE_SUMMARY_ID:
pass
# We don't need to run section merging in this flow, this variable is only used
# below to specify the ordering of the documents for the purpose of matching
# citations to the right search documents. The deduplication logic is more lightweight
# there and we don't need to do it twice
search_results = [
llm_doc_from_inference_section(section)
for section in cast(
SearchResponseSummary, message.response
).top_sections
]
# search_results = [
# llm_doc_from_inference_section(section)
# for section in cast(
# SearchResponseSummary, message.response
# ).top_sections
# ]
elif message.id == FINAL_CONTEXT_DOCUMENTS:
final_context_docs = cast(list[LlmDoc], message.response)
self.final_context_docs = final_context_docs
elif (
message.id == SEARCH_DOC_CONTENT_ID
and not self._return_contexts
):
continue
yield message
else:
# assumes all tool responses will come first, then the final answer
break
if message == "FINAL TOKEN":
self.current_streamed_output = self.processing_stream
self.processing_stream = []
yield Delimiter(delimiter=True)
if not self.skip_gen_ai_answer_generation:
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
elif isinstance(message, str):
yield DanswerAnswerPiece(answer_piece=str(message))
else:
yield message
def _stream() -> Iterator[str]:
if message:
yield cast(str, message)
yield from cast(Iterator[str], stream)
yield from process_answer_stream_fn(_stream())
processed_stream = []
for processed_packet in _process_stream(output_generator):
processed_stream.append(processed_packet)
self.processing_stream.append(processed_packet)
yield processed_packet
self._processed_stream = processed_stream
self._processed_stream = self.processing_stream
@property
def llm_answer(self) -> str:
answer = ""
for packet in self.processed_streamed_output:
if not self._processed_stream and not self.current_streamed_output:
return ""
for packet in self.current_streamed_output or self._processed_stream or []:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
return answer
@property
def citations(self) -> list[CitationInfo]:
citations: list[CitationInfo] = []
for packet in self.processed_streamed_output:
for packet in self.current_streamed_output:
if isinstance(packet, CitationInfo):
citations.append(packet)
return citations
@property
def is_cancelled(self) -> bool:
if self._is_cancelled:
return True
if self.is_connected is not None:
if not self.is_connected():
logger.debug("Answer stream has been cancelled")
self._is_cancelled = not self.is_connected()
return self._is_cancelled

View File

@@ -16,7 +16,7 @@ from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallMetadata
if TYPE_CHECKING:
from danswer.db.models import ChatMessage
@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_calls: list[ToolCallFinalResult]
tool_call: ToolCallMetadata | None
@classmethod
def from_chat_message(
@@ -51,14 +51,13 @@ class PreviousMessage(BaseModel):
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
tool_call=ToolCallMetadata(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
)
def to_langchain_msg(self) -> BaseMessage:

View File

@@ -133,6 +133,10 @@ class AnswerPromptBuilder:
)
)
return drop_messages_history_overflow(
response = drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens
)
for msg in response:
print(f"{msg.type} : \t \t ||||||||| {msg.content[:20]}")
return response

View File

@@ -9,6 +9,7 @@ from requests import Timeout
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.llm.utils import convert_lm_input_to_basic_string
@@ -73,6 +74,12 @@ class CustomModelServer(LLM):
response_content = json.loads(response.content).get("generated_text", "")
return AIMessage(content=response_content)
@classmethod
def create_prompt(cls, message: PreviousMessage) -> str:
# TODO improve / iterate
return f'I searched for some things! """thigns that I searched for!: {message.message}"""'
def log_model_configs(self) -> None:
logger.debug(f"Custom model at: {self._endpoint}")

View File

@@ -0,0 +1,14 @@
from danswer.configs.constants import MessageType
from danswer.llm.answering.models import PreviousMessage
def create_previous_message(
assistant_content: str, token_count: int
) -> PreviousMessage:
return PreviousMessage(
message=assistant_content,
message_type=MessageType.ASSISTANT,
token_count=token_count,
files=[],
tool_call=None,
)

View File

@@ -10,11 +10,11 @@ import litellm # type: ignore
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import AIMessage
from langchain.schema import BaseMessage
from langchain.schema import HumanMessage
from langchain.schema import PromptValue
from langchain.schema.language_model import LanguageModelInput
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from litellm.exceptions import APIConnectionError # type: ignore
from litellm.exceptions import APIError # type: ignore
@@ -46,6 +46,27 @@ if TYPE_CHECKING:
logger = setup_logger()
# def translate_danswer_msg_to_langchain(
# msg: Union[ChatMessage, "PreviousMessage"]
# ) -> BaseMessage:
# files: list[InMemoryChatFile] = []
# # If the message is a `ChatMessage`, it doesn't have the downloaded files
# # attached. Just ignore them for now. Also, OpenAI doesn't allow files to
# # be attached to AI messages, so we must remove them
# if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
# files = msg.files
# content = build_content_with_imgs(msg.message, files)
# if msg.message_type == MessageType.SYSTEM:
# print("SYSTE MESAGE")
# print(msg.message)
# # raise ValueError("System messages are not currently part of history")
# if msg.message_type == MessageType.ASSISTANT:
# return AIMessage(content=content)
# if msg.message_type == MessageType.USER:
# return HumanMessage(content=content)
def litellm_exception_to_error_msg(e: Exception, llm: LLM) -> str:
error_msg = str(e)
@@ -100,21 +121,38 @@ def litellm_exception_to_error_msg(e: Exception, llm: LLM) -> str:
def translate_danswer_msg_to_langchain(
msg: Union[ChatMessage, "PreviousMessage"],
msg: Union[ChatMessage, "PreviousMessage"], token_count: int
) -> BaseMessage:
files: list[InMemoryChatFile] = []
# If the message is a `ChatMessage`, it doesn't have the downloaded files
# attached. Just ignore them for now. Also, OpenAI doesn't allow files to
# be attached to AI messages, so we must remove them
if not isinstance(msg, ChatMessage) and msg.message_type != MessageType.ASSISTANT:
files = msg.files
content = build_content_with_imgs(msg.message, files)
content = msg.message
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
return SystemMessage(content=content)
wrapped_content = ""
if msg.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
try:
parsed_content = json.loads(content)
if (
"name" in parsed_content
and parsed_content["name"] == "run_image_generation"
):
wrapped_content += f"I, the AI, am now generating an \
image based on the prompt: '{parsed_content['args']['prompt']}'\n"
wrapped_content += "[/AI IMAGE GENERATION REQUEST]"
elif (
"id" in parsed_content
and parsed_content["id"] == "image_generation_response"
):
wrapped_content += "I, the AI, have generated the following image(s) based on the previous request:\n"
for img in parsed_content["response"]:
wrapped_content += f"- Description: {img['revised_prompt']}\n"
wrapped_content += f" Image URL: {img['url']}\n\n"
wrapped_content += "[/AI IMAGE GENERATION RESPONSE]"
else:
wrapped_content = content
except json.JSONDecodeError:
wrapped_content = content
return AIMessage(content=wrapped_content)
if msg.message_type == MessageType.USER:
return HumanMessage(content=content)
@@ -124,15 +162,103 @@ def translate_danswer_msg_to_langchain(
def translate_history_to_basemessages(
history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
print("message history is")
new_history = []
assistant_content = None
token_count = 1
from danswer.llm.temporary import create_previous_message
from danswer.tools.tool import ToolRegistry
from danswer.llm.answering.models import PreviousMessage
for i, msg in enumerate(history):
message = cast(ChatMessage, msg)
if message.message_type != MessageType.ASSISTANT:
if assistant_content is not None:
combined_ai_message = create_previous_message(
assistant_content, token_count
)
assistant_content = None
new_history.append(combined_ai_message)
new_history.append(cast(PreviousMessage, message))
continue
if message.tool_call and message.tool_call.tool_name == "run_image_generation":
assistant_content = (assistant_content or "") + ToolRegistry.get_prompt(
"run_image_generation", message
)
# assistant_content = assistant_content or "" + f"I generated images with these descriptions! {message.message}"
else:
assistant_content = message.message
# TODO make better + fix token counting
if assistant_content is not None:
combined_ai_message = create_previous_message(assistant_content, token_count)
new_history.append(combined_ai_message)
history = new_history
for h in history:
print(f"\t\t{h.message_type}: \t\t|| {h.message[:100]}\n\n")
history_basemessages = [
translate_danswer_msg_to_langchain(msg)
translate_danswer_msg_to_langchain(msg, 0)
for msg in history
if msg.token_count != 0
]
history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
# summary = "[CONVERSATION SUMMARY]\n"
# summary += "The most recent user request may involve generating additional images. "
# summary += "I should carefully review the conversation history an
# d the latest user request and almost defeinitely GENERATE MORE IMAGES "
# summary += "[/CONVERSATION SUMMARY]"
# history_basemessages.append(AIMessage(content=summary))
# history_token_counts.append(100)
# print()
for msg in history_basemessages:
print(f"{msg.type} : \t \t ||||||||| {msg.content[:20]}")
# print(f"{msg.type}: {msg.content[:20]}")
return history_basemessages, history_token_counts
# def translate_history_to_basemessages(
# history: Union[list[ChatMessage], list["PreviousMessage"]]
# ) -> tuple[list[BaseMessage], list[int]]:
# history_basemessages = []
# history_token_counts = []
# image_generation_count = 0
# for msg in history:
# if msg.token_count != 0:
# translated_msg = translate_danswer_msg_to_langchain(
# msg, image_generation_count
# )
# if (
# isinstance(translated_msg.content, str)
# and "[ImageGenerationRe" in translated_msg.content
# ):
# image_generation_count += 1
# history_basemessages.append(translated_msg)
# history_token_counts.append(msg.token_count)
# # Add a generic summary message at the end
# summary = "[CONVERSATION SUMMARY]\n"
# summary += "The most recent user request may involve generating additional images. "
# summary += "I should carefully review the conversation history and the latest user request "
# summary += (
# "to determine if any new images need to be generated, ensuring I don't repeat "
# )
# summary += "any image generations that have already been completed.\n"
# summary += f"I already generated {image_generation_count} images thus far. I should keep my responses EXTREMELY SHORT"
# summary += "[/CONVERSATION SUMMARY]"
# history_basemessages.append(AIMessage(content=summary))
# history_token_counts.append(100)
# return history_basemessages, history_token_counts
def _build_content(
message: str,
files: list[InMemoryChatFile] | None = None,
@@ -189,15 +315,15 @@ def build_content_with_imgs(
for file in files
if file.file_type == "image"
]
+ [
{
"type": "image_url",
"image_url": {
"url": url,
},
}
for url in img_urls
],
# + [
# {
# "type": "image_url",
# "image_url": {
# "url": url,
# },
# }
# for url in img_urls
# ],
)

View File

@@ -32,7 +32,7 @@ def check_if_need_search_multi_message(
return True
prompt_msgs: list[BaseMessage] = [SystemMessage(content=REQUIRE_SEARCH_SYSTEM_MSG)]
prompt_msgs.extend([translate_danswer_msg_to_langchain(msg) for msg in history])
prompt_msgs.extend([translate_danswer_msg_to_langchain(msg, 2) for msg in history])
last_query = query_message.message

View File

@@ -1,5 +1,8 @@
import asyncio
import io
import uuid
from collections.abc import Callable
from collections.abc import Generator
from fastapi import APIRouter
from fastapi import Depends
@@ -207,8 +210,6 @@ def rename_chat_session(
chat_session_id = rename_req.chat_session_id
user_id = user.id if user is not None else None
logger.info(f"Received rename request for chat session: {chat_session_id}")
if name:
update_chat_session(
db_session=db_session,
@@ -271,19 +272,39 @@ def delete_chat_session_by_id(
delete_chat_session(user_id, session_id, db_session)
async def is_disconnected(request: Request) -> Callable[[], bool]:
main_loop = asyncio.get_event_loop()
def is_disconnected_sync() -> bool:
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
try:
return not future.result(timeout=0.01)
except asyncio.TimeoutError:
logger.error("Asyncio timed out")
return True
except Exception as e:
error_msg = str(e)
logger.critical(
f"An unexpected error occured with the disconnect check coroutine: {error_msg}"
)
return True
return is_disconnected_sync
@router.post("/send-message")
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_user),
_: None = Depends(check_token_rate_limits),
is_disconnected_func: Callable[[], bool] = Depends(is_disconnected),
) -> StreamingResponse:
"""This endpoint is both used for all the following purposes:
- Sending a new message in the session
- Regenerating a message in the session (just send the same one again)
- Editing a message (similar to regenerating but sending a different message)
- Kicking off a seeded chat session (set `use_existing_user_message`)
To avoid extra overhead/latency, this assumes (and checks) that previous messages on the path
have already been set as latest"""
logger.debug(f"Received new chat message: {chat_message_req.message}")
@@ -295,15 +316,26 @@ def handle_new_chat_message(
):
raise HTTPException(status_code=400, detail="Empty chat message is invalid")
packets = stream_chat_message(
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=get_litellm_additional_request_headers(
request.headers
),
)
return StreamingResponse(packets, media_type="application/json")
import json
def stream_generator() -> Generator[str, None, None]:
try:
for packet in stream_chat_message(
new_msg_req=chat_message_req,
user=user,
use_existing_user_message=chat_message_req.use_existing_user_message,
litellm_additional_headers=get_litellm_additional_request_headers(
request.headers
),
is_connected=is_disconnected_func,
):
yield json.dumps(packet) if isinstance(packet, dict) else packet
except Exception as e:
logger.exception(f"Error in chat message streaming: {e}")
yield json.dumps({"error": str(e)})
return StreamingResponse(stream_generator(), media_type="text/event-stream")
@router.put("/set-message-as-latest")

View File

@@ -17,7 +17,7 @@ from danswer.search.models import ChunkContext
from danswer.search.models import RetrievalDetails
from danswer.search.models import SearchDoc
from danswer.search.models import Tag
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallMetadata
class SourceTag(Tag):
@@ -97,7 +97,9 @@ class CreateChatMessageRequest(ChunkContext):
# allows the caller to specify the exact search query they want to use
# will disable Query Rewording if specified
query_override: str | None = None
alternate_model: str | None = None # Added optional string for alternate model
regenerate: bool | None = None
# allows the caller to override the Persona / Prompt
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
@@ -179,11 +181,12 @@ class ChatMessageDetail(BaseModel):
message_type: MessageType
time_sent: datetime
alternate_assistant_id: str | None
alternate_model: str | None
# Dict mapping citation number to db_doc_id
chat_session_id: int | None = None
citations: dict[int, int] | None
files: list[FileDescriptor]
tool_calls: list[ToolCallFinalResult]
tool_call: ToolCallMetadata | None
def dict(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().dict(*args, **kwargs) # type: ignore

View File

@@ -60,6 +60,12 @@ class CustomTool(Tool):
"""For LLMs which support explicit tool calling"""
@classmethod
def create_prompt(cls, message: PreviousMessage) -> str:
# TODO improve / iterate
return f'I searched for some things! """thigns that I searched for!: {message.message}"""'
def tool_definition(self) -> dict:
return self._tool_definition

View File

@@ -17,6 +17,7 @@ from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolRegistry
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -31,7 +32,7 @@ YES_IMAGE_GENERATION = "Yes Image Generation"
SKIP_IMAGE_GENERATION = "Skip Image Generation"
IMAGE_GENERATION_TEMPLATE = f"""
Given the conversation history and a follow up query, determine if the system should call \
Given the conversation history and a follow up query, determine if the system should call
an external image generation tool to better answer the latest user input.
Your default response is {SKIP_IMAGE_GENERATION}.
@@ -62,6 +63,7 @@ class ImageShape(str, Enum):
LANDSCAPE = "landscape"
@ToolRegistry.register("run_image_generation")
class ImageGenerationTool(Tool):
_NAME = "run_image_generation"
_DESCRIPTION = "Generate an image from a prompt."
@@ -121,6 +123,11 @@ class ImageGenerationTool(Tool):
},
}
@classmethod
def create_prompt(cls, message: PreviousMessage) -> str:
# TODO improve / iterate
return f'I generated images with these descriptions! """Descriptions: {message.message}"""'
def get_args_for_non_tool_calling_llm(
self,
query: str,
@@ -210,7 +217,8 @@ class ImageGenerationTool(Tool):
in error_message
):
raise ValueError(
"The image generation request was rejected due to OpenAI's content policy. Please try a different prompt."
"The image generation request was rejected due to OpenAI's content policy."
+ "Please try a different prompt."
)
elif "Invalid image URL" in error_message:
raise ValueError("Invalid image URL provided for image generation.")

View File

@@ -6,7 +6,7 @@ from danswer.llm.utils import build_content_with_imgs
IMG_GENERATION_SUMMARY_PROMPT = """
You have just created the attached images in response to the following query: "{query}".
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
Can you please summarize them in a sentence or two? NEVER include image urls or bulleted lists.
"""

View File

@@ -119,6 +119,12 @@ class InternetSearchTool(Tool):
def display_name(self) -> str:
return self._DISPLAY_NAME
@classmethod
def create_prompt(cls, message: PreviousMessage) -> str:
# TODO improve / iterate
return f'I searched for some things! """thigns that I searched for!: {message.message}"""'
def tool_definition(self) -> dict:
return {
"type": "function",

View File

@@ -35,5 +35,5 @@ class ToolRunnerResponse(BaseModel):
return values
class ToolCallFinalResult(ToolCallKickoff):
class ToolCallMetadata(ToolCallKickoff):
tool_result: Any # we would like to use JSON_ro, but can't due to its recursive nature

View File

@@ -32,6 +32,7 @@ from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.search.search_utils import llm_doc_to_dict
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolRegistry
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
@@ -65,6 +66,7 @@ HINT: if you are unfamiliar with the user input OR think the user input is a typ
"""
@ToolRegistry.register(SEARCH_RESPONSE_SUMMARY_ID)
class SearchTool(Tool):
_NAME = "run_search"
_DISPLAY_NAME = "Search Tool"
@@ -175,6 +177,12 @@ class SearchTool(Tool):
)
return {"query": rephrased_query}
@classmethod
def create_prompt(cls, message: PreviousMessage) -> str:
# TODO improve / iterate
return f'I searched for some things! """thigns that I searched for!: {message.message}"""'
"""Actual tool execution"""
def _build_response_for_specified_sections(

View File

@@ -1,12 +1,18 @@
import abc
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import Dict
from typing import Type
from danswer.dynamic_configs.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.models import ToolResponse
# from danswer.llm.answering.models import ChatMessage
class Tool(abc.ABC):
@property
@@ -24,6 +30,11 @@ class Tool(abc.ABC):
def display_name(self) -> str:
raise NotImplementedError
@classmethod
@abc.abstractmethod
def create_prompt(self, message: PreviousMessage) -> str:
raise NotImplementedError
"""For LLMs which support explicit tool calling"""
@abc.abstractmethod
@@ -61,3 +72,29 @@ class Tool(abc.ABC):
It is the result that will be stored in the database.
"""
raise NotImplementedError
class ToolRegistry:
_registry: Dict[str, Type[Tool]] = {}
@classmethod
def register(cls, tool_id: str) -> Callable[[type[Tool]], type[Tool]]:
def decorator(tool_class: Type[Tool]) -> type[Tool]:
cls._registry[tool_id] = tool_class
return tool_class
return decorator
@classmethod
def get_tool(cls, tool_id: str) -> Type[Tool]:
if tool_id not in cls._registry:
raise ValueError(f"No tool registered with id: {tool_id}")
return cls._registry[tool_id]
@classmethod
def get_prompt(cls, tool_id: str, message: PreviousMessage) -> str:
if tool_id not in cls._registry:
raise ValueError(f"No tool registered with id: {tool_id}")
tool = cls._registry[tool_id]
new_prompt = tool.create_prompt(message=cast(PreviousMessage, message))
return new_prompt

View File

@@ -3,8 +3,8 @@ from typing import Any
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolCallMetadata
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -36,8 +36,8 @@ class ToolRunner:
tool_responses = list(self.tool_responses())
return self.tool.build_tool_message_content(*tool_responses)
def tool_final_result(self) -> ToolCallFinalResult:
return ToolCallFinalResult(
def tool_final_result(self) -> ToolCallMetadata:
return ToolCallMetadata(
tool_name=self.tool.name,
tool_args=self.args,
tool_result=self.tool.final_result(*self.tool_responses()),

View File

@@ -23,7 +23,6 @@ from ee.danswer.server.enterprise_settings.store import (
)
from ee.danswer.server.enterprise_settings.store import upload_logo
logger = setup_logger()
_SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION"
@@ -36,7 +35,8 @@ class SeedConfiguration(BaseModel):
personas: list[CreatePersonaRequest] | None = None
settings: Settings | None = None
enterprise_settings: EnterpriseSettings | None = None
analytics_script: AnalyticsScriptUpload | None = None
analytics_script_key: str | None = None
analytics_script_path: str | None = None
def _parse_env() -> SeedConfiguration | None:
@@ -119,10 +119,19 @@ def _seed_logo(db_session: Session, logo_path: str | None) -> None:
def _seed_analytics_script(seed_config: SeedConfiguration) -> None:
if seed_config.analytics_script is not None:
if seed_config.analytics_script_path and seed_config.analytics_script_key:
logger.info("Seeding analytics script")
try:
store_analytics_script(seed_config.analytics_script)
with open(seed_config.analytics_script_path, "r") as file:
script_content = file.read()
analytics_script = AnalyticsScriptUpload(
script=script_content, secret_key=seed_config.analytics_script_key
)
store_analytics_script(analytics_script)
except FileNotFoundError:
logger.error(
f"Analytics script file not found: {seed_config.analytics_script_path}"
)
except ValueError as e:
logger.error(f"Failed to seed analytics script: {str(e)}")
@@ -133,7 +142,6 @@ def get_seed_config() -> SeedConfiguration | None:
def seed_db() -> None:
seed_config = _parse_env()
if seed_config is None:
logger.info("No seeding configuration file passed")
return

View File

@@ -49,6 +49,12 @@ ENV NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PRED
ARG NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN
ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN}
ARG NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH
ENV NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH=${NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH}
ARG NEXT_PUBLIC_THEME
ENV NEXT_PUBLIC_THEME=${NEXT_PUBLIC_THEME}
@@ -112,6 +118,9 @@ ENV NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED=${NEXT_PUBLIC_DO_NOT_USE_T
ARG NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN
ENV NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN}
ARG NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH
ENV NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH=${NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH}
ARG NEXT_PUBLIC_DISABLE_LOGOUT
ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}

View File

@@ -14,26 +14,6 @@ const AddPromptSchema = Yup.object().shape({
});
const AddPromptModal = ({ onClose, onSubmit }: AddPromptModalProps) => {
const defaultPrompts = [
{
title: "Email help",
prompt: "Write a professional email addressing the following points:",
},
{
title: "Code explanation",
prompt: "Explain the following code snippet in simple terms:",
},
{
title: "Product description",
prompt: "Write a compelling product description for the following item:",
},
{
title: "Troubleshooting steps",
prompt:
"Provide step-by-step troubleshooting instructions for the following issue:",
},
];
return (
<ModalWrapper onClose={onClose} modalClassName="max-w-xl">
<Formik

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,186 @@
import { useChatContext } from "@/components/context/ChatContext";
import {
getDisplayNameForModel,
LlmOverride,
useLlmOverride,
} from "@/lib/hooks";
import {
DefaultDropdownElement,
StringOrNumberOption,
} from "@/components/Dropdown";
import { Persona } from "@/app/admin/assistants/interfaces";
import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
import { useState } from "react";
import { Hoverable } from "@/components/Hoverable";
import { Popover } from "@/components/popover/Popover";
import { FiStar } from "react-icons/fi";
import { StarFeedback } from "@/components/icons/icons";
import { IconType } from "react-icons";
export function RegenerateDropdown({
options,
selected,
onSelect,
side,
maxHeight,
gptBox,
alternate,
}: {
alternate?: string;
options: StringOrNumberOption[];
selected: string | null;
onSelect: (value: string | number | null) => void;
includeDefault?: boolean;
side?: "top" | "right" | "bottom" | "left";
maxHeight?: string;
gptBox?: boolean;
}) {
const [isOpen, setIsOpen] = useState(false);
const Dropdown = (
<div
className={`
border
border
rounded-lg
flex
flex-col
mx-2
bg-background
${maxHeight || "max-h-96"}
overflow-y-auto
overscroll-contain relative`}
>
<p
className="
sticky
top-0
flex
bg-background
font-bold
px-3
text-sm
py-1.5
"
>
Pick a model
</p>
{options.map((option, ind) => {
const isSelected = option.value === selected;
return (
<DefaultDropdownElement
key={option.value}
name={getDisplayNameForModel(option.name)}
description={option.description}
onSelect={() => onSelect(option.value)}
isSelected={isSelected}
/>
);
})}
</div>
);
return (
<Popover
open={isOpen}
onOpenChange={(open) => setIsOpen(open)}
content={
<div onClick={() => setIsOpen(!isOpen)}>
{!alternate ? (
<Hoverable size={16} icon={StarFeedback as IconType} />
) : (
<Hoverable
size={16}
icon={StarFeedback as IconType}
hoverText={getDisplayNameForModel(alternate)}
/>
)}
</div>
}
popover={Dropdown}
align="start"
side={side}
sideOffset={5}
triggerMaxWidth
/>
);
}
export default function RegenerateOption({
selectedAssistant,
regenerate,
alternateModel,
onHoverChange,
}: {
selectedAssistant: Persona;
regenerate: (modelOverRide: LlmOverride) => Promise<void>;
alternateModel?: string;
onHoverChange: (isHovered: boolean) => void;
}) {
const llmOverrideManager = useLlmOverride();
const { llmProviders } = useChatContext();
const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);
const llmOptionsByProvider: {
[provider: string]: { name: string; value: string }[];
} = {};
const uniqueModelNames = new Set<string>();
llmProviders.forEach((llmProvider) => {
if (!llmOptionsByProvider[llmProvider.provider]) {
llmOptionsByProvider[llmProvider.provider] = [];
}
(llmProvider.display_model_names || llmProvider.model_names).forEach(
(modelName) => {
if (!uniqueModelNames.has(modelName)) {
uniqueModelNames.add(modelName);
llmOptionsByProvider[llmProvider.provider].push({
name: modelName,
value: structureValue(
llmProvider.name,
llmProvider.provider,
modelName
),
});
}
}
);
});
const llmOptions = Object.entries(llmOptionsByProvider).flatMap(
([provider, options]) => [...options]
);
const currentModelName =
llmOverrideManager?.llmOverride.modelName ||
(selectedAssistant
? selectedAssistant.llm_model_version_override || llmName
: llmName);
return (
<div
className="group flex items-center relative"
onMouseEnter={() => onHoverChange(true)}
onMouseLeave={() => onHoverChange(false)}
>
<RegenerateDropdown
alternate={alternateModel}
options={llmOptions}
selected={currentModelName}
onSelect={(value) => {
const { name, provider, modelName } = destructureValue(
value as string
);
regenerate({
name: name,
provider: provider,
modelName: modelName,
});
}}
/>
</div>
);
}

View File

@@ -9,6 +9,7 @@ import {
buildDocumentSummaryDisplay,
} from "@/components/search/DocumentDisplay";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { MOBILE_SIDEBAR_CARD_WIDTH, SIDEBAR_CARD_WIDTH } from "@/lib/constants";
interface DocumentDisplayProps {
document: DanswerDocument;
@@ -39,8 +40,8 @@ export function ChatDocumentDisplay({
return (
<div
key={document.semantic_identifier}
className={`p-2 w-[325px] justify-start rounded-md ${
isSelected ? "bg-background-200" : "bg-background-125"
className={`p-2 desktop:${SIDEBAR_CARD_WIDTH} mobile:${MOBILE_SIDEBAR_CARD_WIDTH} justify-start rounded-md ${
isSelected ? "bg-background-200" : "bg-background-150"
} text-sm mx-3`}
>
<div className="flex relative justify-start overflow-y-visible">

View File

@@ -3,12 +3,12 @@ import { Divider, Text } from "@tremor/react";
import { ChatDocumentDisplay } from "./ChatDocumentDisplay";
import { usePopup } from "@/components/admin/connectors/Popup";
import { removeDuplicateDocs } from "@/lib/documentUtils";
import { Message, RetrievalType } from "../interfaces";
import { Message } from "../interfaces";
import { ForwardedRef, forwardRef } from "react";
interface DocumentSidebarProps {
closeSidebar: () => void;
selectedMessage: Message | null;
currentDocuments: DanswerDocument[] | null;
selectedDocuments: DanswerDocument[] | null;
toggleDocumentSelection: (document: DanswerDocument) => void;
clearSelectedDocuments: () => void;
@@ -23,7 +23,7 @@ export const DocumentSidebar = forwardRef<HTMLDivElement, DocumentSidebarProps>(
(
{
closeSidebar,
selectedMessage,
currentDocuments,
selectedDocuments,
toggleDocumentSelection,
clearSelectedDocuments,
@@ -40,7 +40,6 @@ export const DocumentSidebar = forwardRef<HTMLDivElement, DocumentSidebarProps>(
const selectedDocumentIds =
selectedDocuments?.map((document) => document.document_id) || [];
const currentDocuments = selectedMessage?.documents || null;
const dedupedDocuments = removeDuplicateDocs(currentDocuments || []);
// NOTE: do not allow selection if less than 75 tokens are left
@@ -72,14 +71,8 @@ export const DocumentSidebar = forwardRef<HTMLDivElement, DocumentSidebarProps>(
{popup}
<div className="pl-3 mx-2 pr-6 mt-3 flex text-text-800 flex-col text-2xl text-emphasis flex font-semibold">
{dedupedDocuments.length} Documents
<p className="text-sm font-semibold flex flex-wrap gap-x-2 text-text-600 mt-1">
<p className="text-sm flex flex-wrap gap-x-2 font-normal text-text-700 mt-1">
Select to add to continuous context
<a
href="https://docs.danswer.dev/introduction"
className="underline cursor-pointer hover:text-strong"
>
Learn more
</a>
</p>
</div>
@@ -137,14 +130,14 @@ export const DocumentSidebar = forwardRef<HTMLDivElement, DocumentSidebarProps>(
<div className="absolute left-0 bottom-0 w-full bg-gradient-to-b from-neutral-100/0 via-neutral-100/40 backdrop-blur-xs to-neutral-100 h-[100px]" />
<div className="sticky bottom-4 w-full left-0 justify-center flex gap-x-4">
<button
className="bg-[#84e49e] text-xs p-2 rounded text-text-800"
className="bg-background-800 px-3 hover:bg-background-600 transition-background duration-300 py-2.5 scale-[.95] rounded text-text-200"
onClick={() => closeSidebar()}
>
Save Changes
</button>
<button
className="bg-error text-xs p-2 rounded text-text-200"
className="bg-background-125 hover:bg-background-150 transition-background duration-300 ring ring-1 ring-border scale-[.95] px-3 py-2.5 rounded text-text-900"
onClick={() => {
clearSelectedDocuments();

View File

@@ -41,7 +41,7 @@ export function FullImageModal({
<img
src={buildImgUrl(fileId)}
alt="Uploaded image"
className="max-w-full max-h-full"
className="max-w-full rounded-lg max-h-full"
/>
</Dialog.Content>
</Dialog.Portal>

View File

@@ -15,10 +15,11 @@ export function InMessageImage({ fileId }: { fileId: string }) {
/>
<div className="relative w-full h-full max-w-96 max-h-96">
{!imageLoaded && (
<div className="absolute inset-0 bg-gray-200 animate-pulse rounded-lg" />
)}
<div
className={`absolute inset-0 bg-gray-200 rounded-lg transition-opacity duration-300 ${
imageLoaded ? "opacity-0" : "opacity-100"
}`}
/>
<img
width={1200}
height={1200}

View File

@@ -21,6 +21,7 @@ import {
CpuIconSkeleton,
FileIcon,
SendIcon,
StopGeneratingIcon,
} from "@/components/icons/icons";
import { IconType } from "react-icons";
import Popup from "../../../components/popup/Popup";
@@ -31,6 +32,9 @@ import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { Hoverable } from "@/components/Hoverable";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { StopCircle } from "@phosphor-icons/react/dist/ssr";
import { Square } from "@phosphor-icons/react";
import { ChatState } from "../types";
const MAX_INPUT_HEIGHT = 200;
export function ChatInputBar({
@@ -39,10 +43,11 @@ export function ChatInputBar({
selectedDocuments,
message,
setMessage,
stopGenerating,
onSubmit,
isStreaming,
filterManager,
llmOverrideManager,
chatState,
// assistants
selectedAssistant,
@@ -59,6 +64,8 @@ export function ChatInputBar({
inputPrompts,
}: {
openModelSettings: () => void;
chatState: ChatState;
stopGenerating: () => void;
showDocs: () => void;
selectedDocuments: DanswerDocument[];
assistantOptions: Persona[];
@@ -68,7 +75,6 @@ export function ChatInputBar({
message: string;
setMessage: (message: string) => void;
onSubmit: () => void;
isStreaming: boolean;
filterManager: FilterManager;
llmOverrideManager: LlmOverrideManager;
selectedAssistant: Persona;
@@ -597,24 +603,38 @@ export function ChatInputBar({
}}
/>
</div>
<div className="absolute bottom-2.5 mobile:right-4 desktop:right-10">
<div
className="cursor-pointer"
onClick={() => {
if (message) {
onSubmit();
}
}}
>
<SendIcon
size={28}
className={`text-emphasis text-white p-1 rounded-full ${
message && !isStreaming
? "bg-background-800"
: "bg-[#D7D7D7]"
}`}
/>
</div>
{chatState == "streaming" ||
chatState == "toolBuilding" ||
chatState == "loading" ? (
<button
className={`cursor-pointer ${chatState != "streaming" ? "bg-background-400" : "bg-background-800"} h-[28px] w-[28px] rounded-full`}
onClick={stopGenerating}
disabled={chatState != "streaming"}
>
<StopGeneratingIcon
size={10}
className={`text-emphasis m-auto text-white flex-none
}`}
/>
</button>
) : (
<button
className="cursor-pointer"
onClick={() => {
if (message) {
onSubmit();
}
}}
disabled={chatState != "input"}
>
<SendIcon
size={28}
className={`text-emphasis text-white p-1 rounded-full ${chatState == "input" && message ? "bg-background-800" : "bg-background-400"} `}
/>
</button>
)}
</div>
</div>
</div>

View File

@@ -49,12 +49,6 @@ export interface ToolCallMetadata {
tool_result?: Record<string, any>;
}
export interface ToolCallFinalResult {
tool_name: string;
tool_args: Record<string, any>;
tool_result: Record<string, any>;
}
export interface ChatSession {
id: number;
name: string;
@@ -72,6 +66,24 @@ export interface SearchSession {
description: string;
}
export interface PreviousAIMessage {
messageId?: number;
message?: string;
type?: "assistant";
retrievalType?: RetrievalType;
query?: string | null;
documents?: DanswerDocument[] | null;
citations?: CitationMap;
files?: FileDescriptor[];
toolCall?: ToolCallMetadata | null;
// for rebuilding the message tree
parentMessageId?: number | null;
childrenMessageIds?: number[];
latestChildMessageId?: number | null;
alternateAssistantID?: number | null;
}
export interface Message {
messageId: number;
message: string;
@@ -81,13 +93,15 @@ export interface Message {
documents?: DanswerDocument[] | null;
citations?: CitationMap;
files: FileDescriptor[];
toolCalls: ToolCallMetadata[];
toolCall: ToolCallMetadata | null;
// for rebuilding the message tree
parentMessageId: number | null;
childrenMessageIds?: number[];
latestChildMessageId?: number | null;
alternateAssistantID?: number | null;
stackTrace?: string | null;
alternate_model?: string;
}
export interface BackendChatSession {
@@ -114,8 +128,14 @@ export interface BackendMessage {
time_sent: string;
citations: CitationMap;
files: FileDescriptor[];
tool_calls: ToolCallFinalResult[];
tool_call: ToolCallMetadata | null;
alternate_assistant_id?: number | null;
alternate_model?: string;
}
export interface MessageResponseIDInfo {
user_message_id: number | null;
reserved_assistant_message_id: number;
}
export interface DocumentsResponse {

View File

@@ -3,8 +3,8 @@ import {
DanswerDocument,
Filters,
} from "@/lib/search/interfaces";
import { handleStream } from "@/lib/search/streamingUtils";
import { FeedbackType } from "./types";
import { handleSSEStream, handleStream } from "@/lib/search/streamingUtils";
import { ChatState, FeedbackType } from "./types";
import {
Dispatch,
MutableRefObject,
@@ -20,6 +20,7 @@ import {
FileDescriptor,
ImageGenerationDisplay,
Message,
MessageResponseIDInfo,
RetrievalType,
StreamingError,
ToolCallMetadata,
@@ -109,9 +110,11 @@ export type PacketType =
| AnswerPiecePacket
| DocumentsResponse
| ImageGenerationDisplay
| StreamingError;
| StreamingError
| MessageResponseIDInfo;
export async function* sendMessage({
regenerate,
message,
fileDescriptors,
parentMessageId,
@@ -127,7 +130,9 @@ export async function* sendMessage({
systemPromptOverride,
useExistingUserMessage,
alternateAssistantId,
signal,
}: {
regenerate: boolean;
message: string;
fileDescriptors: FileDescriptor[];
parentMessageId: number | null;
@@ -137,70 +142,70 @@ export async function* sendMessage({
selectedDocumentIds: number[] | null;
queryOverride?: string;
forceSearch?: boolean;
// LLM overrides
modelProvider?: string;
modelVersion?: string;
temperature?: number;
// prompt overrides
systemPromptOverride?: string;
// if specified, will use the existing latest user message
// and will ignore the specified `message`
useExistingUserMessage?: boolean;
alternateAssistantId?: number;
}) {
signal?: AbortSignal;
}): AsyncGenerator<PacketType, void, unknown> {
const documentsAreSelected =
selectedDocumentIds && selectedDocumentIds.length > 0;
const sendMessageResponse = await fetch("/api/chat/send-message", {
const body = JSON.stringify({
alternate_assistant_id: alternateAssistantId,
chat_session_id: chatSessionId,
parent_message_id: parentMessageId,
message: message,
prompt_id: promptId,
search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
file_descriptors: fileDescriptors,
regenerate,
retrieval_options: !documentsAreSelected
? {
run_search:
promptId === null ||
promptId === undefined ||
queryOverride ||
forceSearch
? "always"
: "auto",
real_time: true,
filters: filters,
}
: null,
query_override: queryOverride,
prompt_override: systemPromptOverride
? {
system_prompt: systemPromptOverride,
}
: null,
llm_override:
temperature || modelVersion
? {
temperature,
model_provider: modelProvider,
model_version: modelVersion,
}
: null,
use_existing_user_message: useExistingUserMessage,
});
const response = await fetch(`/api/chat/send-message`, {
method: "POST",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
alternate_assistant_id: alternateAssistantId,
chat_session_id: chatSessionId,
parent_message_id: parentMessageId,
message: message,
prompt_id: promptId,
search_doc_ids: documentsAreSelected ? selectedDocumentIds : null,
file_descriptors: fileDescriptors,
retrieval_options: !documentsAreSelected
? {
run_search:
promptId === null ||
promptId === undefined ||
queryOverride ||
forceSearch
? "always"
: "auto",
real_time: true,
filters: filters,
}
: null,
query_override: queryOverride,
prompt_override: systemPromptOverride
? {
system_prompt: systemPromptOverride,
}
: null,
llm_override:
temperature || modelVersion
? {
temperature,
model_provider: modelProvider,
model_version: modelVersion,
}
: null,
use_existing_user_message: useExistingUserMessage,
}),
body,
signal,
});
if (!sendMessageResponse.ok) {
const errorJson = await sendMessageResponse.json();
const errorMsg = errorJson.message || errorJson.detail || "";
throw Error(`Failed to send message - ${errorMsg}`);
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`);
}
yield* handleStream<PacketType>(sendMessageResponse);
yield* handleSSEStream<PacketType>(response);
}
export async function nameChatSession(chatSessionId: number, message: string) {
@@ -384,13 +389,12 @@ export function getLastSuccessfulMessageId(messageHistory: Message[]) {
.reverse()
.find(
(message) =>
message.type === "assistant" &&
(message.type === "assistant" || message.type === "system") &&
message.messageId !== -1 &&
message.messageId !== null
);
return lastSuccessfulMessage ? lastSuccessfulMessage?.messageId : null;
}
export function processRawChatHistory(
rawMessages: BackendMessage[]
): Map<number, Message> {
@@ -429,10 +433,11 @@ export function processRawChatHistory(
citations: messageInfo?.citations || {},
}
: {}),
toolCalls: messageInfo.tool_calls,
toolCall: messageInfo.tool_call,
parentMessageId: messageInfo.parent_message,
childrenMessageIds: [],
latestChildMessageId: messageInfo.latest_child_message,
alternate_model: messageInfo.alternate_model,
};
messages.set(messageInfo.message_id, message);
@@ -635,14 +640,14 @@ export async function uploadFilesForChat(
}
export async function useScrollonStream({
isStreaming,
chatState,
scrollableDivRef,
scrollDist,
endDivRef,
distance,
debounce,
}: {
isStreaming: boolean;
chatState: ChatState;
scrollableDivRef: RefObject<HTMLDivElement>;
scrollDist: MutableRefObject<number>;
endDivRef: RefObject<HTMLDivElement>;
@@ -656,7 +661,7 @@ export async function useScrollonStream({
const previousScroll = useRef<number>(0);
useEffect(() => {
if (isStreaming && scrollableDivRef && scrollableDivRef.current) {
if (chatState != "input" && scrollableDivRef && scrollableDivRef.current) {
let newHeight: number = scrollableDivRef.current?.scrollTop!;
const heightDifference = newHeight - previousScroll.current;
previousScroll.current = newHeight;
@@ -712,7 +717,7 @@ export async function useScrollonStream({
// scroll on end of stream if within distance
useEffect(() => {
if (scrollableDivRef?.current && !isStreaming) {
if (scrollableDivRef?.current && chatState == "input") {
if (scrollDist.current < distance - 50) {
scrollableDivRef?.current?.scrollBy({
left: 0,
@@ -721,5 +726,5 @@ export async function useScrollonStream({
});
}
}
}, [isStreaming]);
}, [chatState]);
}

View File

@@ -17,6 +17,7 @@ import {
useState,
} from "react";
import ReactMarkdown from "react-markdown";
import Prism from "prismjs";
import {
DanswerDocument,
FilteredDanswerDocument,
@@ -44,22 +45,29 @@ import "./custom-code-styles.css";
import { Persona } from "@/app/admin/assistants/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Citation } from "@/components/search/results/Citation";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
import {
DislikeFeedbackIcon,
LikeFeedbackIcon,
ThumbsUpIcon,
ThumbsDownIcon,
LikeFeedback,
DislikeFeedback,
ToolCallIcon,
} from "@/components/icons/icons";
import {
CustomTooltip,
TooltipGroup,
} from "@/components/tooltip/CustomTooltip";
import { ValidSources } from "@/lib/types";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useMouseTracking } from "./hooks";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import DualPromptDisplay from "../tools/ImagePromptCitation";
import { Popover } from "@/components/popover/Popover";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import { ValidSources } from "@/lib/types";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -110,9 +118,13 @@ function FileDisplay({
}
export const AIMessage = ({
regenerate,
alternateModel,
shared,
isActive,
toggleDocumentSelection,
hasParentAI,
hasChildAI,
alternativeAssistant,
docs,
messageId,
@@ -124,6 +136,7 @@ export const AIMessage = ({
citedDocuments,
toolCall,
isComplete,
generatingTool,
hasDocs,
handleFeedback,
isCurrentlyShowingRetrieved,
@@ -132,9 +145,16 @@ export const AIMessage = ({
handleForceSearch,
retrievalDisabled,
currentPersona,
otherMessagesCanSwitchTo,
onMessageSelection,
setPopup,
}: {
shared?: boolean;
hasChildAI?: boolean;
hasParentAI?: boolean;
isActive?: boolean;
otherMessagesCanSwitchTo?: number[];
onMessageSelection?: (messageId: number) => void;
selectedDocuments?: DanswerDocument[] | null;
toggleDocumentSelection?: () => void;
docs?: DanswerDocument[] | null;
@@ -148,6 +168,7 @@ export const AIMessage = ({
citedDocuments?: [string, DanswerDocument][] | null;
toolCall?: ToolCallMetadata;
isComplete?: boolean;
generatingTool?: boolean;
hasDocs?: boolean;
handleFeedback?: (feedbackType: FeedbackType) => void;
isCurrentlyShowingRetrieved?: boolean;
@@ -155,9 +176,13 @@ export const AIMessage = ({
handleSearchQueryEdit?: (query: string) => void;
handleForceSearch?: () => void;
retrievalDisabled?: boolean;
alternateModel?: string;
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
setPopup: (popupSpec: PopupSpec | null) => void;
}) => {
const toolCallGenerating = toolCall && !toolCall.tool_result;
const processContent = (content: string | JSX.Element) => {
const buildFinalContentDisplay = (content: string | JSX.Element) => {
if (typeof content !== "string") {
return content;
}
@@ -179,10 +204,41 @@ export const AIMessage = ({
}
}
return content + (!isComplete && !toolCallGenerating ? " [*]() " : "");
};
const finalContent = processContent(content as string);
const indicator = !isComplete && !toolCallGenerating ? " [*]()" : "";
const tool_citation =
isComplete && toolCall?.tool_result ? ` [[${toolCall.tool_name}]]()` : "";
return content + indicator + tool_citation;
};
const finalContent = buildFinalContentDisplay(content as string);
const citationRef = useRef<HTMLDivElement | null>(null);
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
if (
citationRef.current &&
!citationRef.current.contains(event.target as Node)
) {
// setIsPopoverOpen(false);
}
};
document.addEventListener("mousedown", handleClickOutside);
return () => {
document.removeEventListener("mousedown", handleClickOutside);
};
}, []);
const [isReady, setIsReady] = useState(false);
useEffect(() => {
Prism.highlightAll();
setIsReady(true);
}, []);
// const finalContent = processContent(content as string);
const [isRegenerateHovered, setIsRegenerateHovered] = useState(false);
const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking();
const settings = useContext(SettingsContext);
@@ -240,59 +296,39 @@ export const AIMessage = ({
});
}
const currentMessageInd = messageId
? otherMessagesCanSwitchTo?.indexOf(messageId)
: undefined;
const uniqueSources: ValidSources[] = Array.from(
new Set((docs || []).map((doc) => doc.source_type))
).slice(0, 3);
const includeMessageSwitcher =
currentMessageInd !== undefined &&
onMessageSelection &&
otherMessagesCanSwitchTo &&
otherMessagesCanSwitchTo.length > 1;
return (
<div ref={trackedElementRef} className={"py-5 px-2 lg:px-5 relative flex "}>
<div
ref={trackedElementRef}
className={`${hasParentAI ? "pb-5" : "py-5"} px-2 lg:px-5 relative flex `}
>
<div
className={`mx-auto ${shared ? "w-full" : "w-[90%]"} max-w-message-max`}
>
<div className={`${!shared && "mobile:ml-4 xl:ml-8"}`}>
<div className="flex">
<AssistantIcon
size="small"
assistant={alternativeAssistant || currentPersona}
/>
{!hasParentAI ? (
<AssistantIcon
size="small"
assistant={alternativeAssistant || currentPersona}
/>
) : (
<div className="w-6" />
)}
<div className="w-full">
<div className="max-w-message-max break-words">
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) &&
danswerSearchToolEnabledForPersona && (
<>
{query !== undefined &&
handleShowRetrieved !== undefined &&
isCurrentlyShowingRetrieved !== undefined &&
!retrievalDisabled && (
<div className="my-1">
<SearchSummary
query={query}
hasDocs={hasDocs || false}
messageId={messageId}
finished={toolCall?.tool_result != undefined}
isCurrentlyShowingRetrieved={
isCurrentlyShowingRetrieved
}
handleShowRetrieved={handleShowRetrieved}
handleSearchQueryEdit={handleSearchQueryEdit}
/>
</div>
)}
{handleForceSearch &&
content &&
query === undefined &&
!hasDocs &&
!retrievalDisabled && (
<div className="my-1">
<SkippedSearch
handleForceSearch={handleForceSearch}
/>
</div>
)}
</>
)}
<div className="w-full ml-4">
<div className="max-w-message-max break-words">
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (
@@ -303,19 +339,22 @@ export const AIMessage = ({
!retrievalDisabled && (
<div className="mb-1">
<SearchSummary
docs={docs}
filteredDocs={filteredDocs}
query={query}
finished={toolCall?.tool_result != undefined}
hasDocs={hasDocs || false}
messageId={messageId}
isCurrentlyShowingRetrieved={
isCurrentlyShowingRetrieved
finished={
toolCall?.tool_result != undefined ||
isComplete!
}
toggleDocumentSelection={
toggleDocumentSelection
}
handleShowRetrieved={handleShowRetrieved}
handleSearchQueryEdit={handleSearchQueryEdit}
/>
</div>
)}
{handleForceSearch &&
!hasChildAI &&
content &&
query === undefined &&
!hasDocs &&
@@ -368,9 +407,8 @@ export const AIMessage = ({
{content || files ? (
<>
<FileDisplay files={files || []} />
{typeof content === "string" ? (
<div className="overflow-x-visible w-full pr-2 max-w-[675px]">
<div className="overflow-visible w-full pr-2 max-w-[675px]">
<ReactMarkdown
key={messageId}
className="prose max-w-full"
@@ -379,7 +417,49 @@ export const AIMessage = ({
const { node, ...rest } = props;
const value = rest.children;
if (value?.toString().startsWith("*")) {
if (
value?.toString() == `[${SEARCH_TOOL_NAME}]`
) {
return <></>;
} else if (
value?.toString() ==
`[${IMAGE_GENERATION_TOOL_NAME}]`
) {
return (
<Popover
open={isPopoverOpen}
onOpenChange={() => null} // only allow closing from the icon
content={
<button
onMouseDown={() => {
setIsPopoverOpen(!isPopoverOpen);
}}
>
<ToolCallIcon className="cursor-pointer flex-none text-blue-500 hover:text-blue-700 !h-4 !w-4 inline-block" />
</button>
}
popover={
<DualPromptDisplay
arg="Prompt"
// ref={citationRef}
setPopup={setPopup}
prompt1={
toolCall?.tool_result?.[0]
?.revised_prompt
}
prompt2={
toolCall?.tool_result?.[1]
?.revised_prompt
}
/>
}
side="top"
align="center"
/>
);
} else if (
value?.toString().startsWith("*")
) {
return (
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
);
@@ -402,7 +482,7 @@ export const AIMessage = ({
return (
<a
key={node?.position?.start?.offset}
onMouseDown={() =>
onClick={() =>
rest.href
? window.open(rest.href, "_blank")
: undefined
@@ -416,7 +496,6 @@ export const AIMessage = ({
},
code: (props) => (
<CodeBlock
className="w-full"
{...props}
content={content as string}
/>
@@ -440,82 +519,10 @@ export const AIMessage = ({
) : isComplete ? null : (
<></>
)}
{isComplete && docs && docs.length > 0 && (
<div className="mt-2 -mx-8 w-full mb-4 flex relative">
<div className="w-full">
<div className="px-8 flex gap-x-2">
{!settings?.isMobile &&
filteredDocs.length > 0 &&
filteredDocs.slice(0, 2).map((doc, ind) => (
<div
key={doc.document_id}
className={`w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 pb-2 pt-1 border-b
`}
>
<a
href={doc.link}
target="_blank"
className="text-sm flex w-full pt-1 gap-x-1.5 overflow-hidden justify-between font-semibold text-text-700"
>
<Citation link={doc.link} index={ind + 1} />
<p className="shrink truncate ellipsis break-all ">
{doc.semantic_identifier ||
doc.document_id}
</p>
<div className="ml-auto flex-none">
{doc.is_internet ? (
<InternetSearchIcon url={doc.link} />
) : (
<SourceIcon
sourceType={doc.source_type}
iconSize={18}
/>
)}
</div>
</a>
<div className="flex overscroll-x-scroll mt-.5">
<DocumentMetadataBlock document={doc} />
</div>
<div className="line-clamp-3 text-xs break-words pt-1">
{doc.blurb}
</div>
</div>
))}
<div
onClick={() => {
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={-1}
className="cursor-pointer w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 py-2 border-b"
>
<div className="text-sm flex justify-between font-semibold text-text-700">
<p className="line-clamp-1">See context</p>
<div className="flex gap-x-1">
{uniqueSources.map((sourceType, ind) => {
return (
<div key={ind} className="flex-none">
<SourceIcon
sourceType={sourceType}
iconSize={18}
/>
</div>
);
})}
</div>
</div>
<div className="line-clamp-3 text-xs break-words pt-1">
See more
</div>
</div>
</div>
</div>
</div>
)}
</div>
{handleFeedback &&
{!hasChildAI &&
handleFeedback &&
isComplete &&
(isActive ? (
<div
className={`
@@ -525,21 +532,53 @@ export const AIMessage = ({
`}
>
<TooltipGroup>
<div className="flex justify-start w-full gap-x-0.5">
{includeMessageSwitcher && (
<div className="-mx-1 mr-auto">
<MessageSwitcher
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() => {
onMessageSelection(
otherMessagesCanSwitchTo[
currentMessageInd - 1
]
);
}}
handleNext={() => {
onMessageSelection(
otherMessagesCanSwitchTo[
currentMessageInd + 1
]
);
}}
/>
</div>
)}
</div>
<CustomTooltip showTick line content="Copy!">
<CopyButton content={content.toString()} />
</CustomTooltip>
<CustomTooltip showTick line content="Good response!">
<HoverableIcon
icon={<LikeFeedbackIcon />}
icon={<LikeFeedback />}
onClick={() => handleFeedback("like")}
/>
</CustomTooltip>
<CustomTooltip showTick line content="Bad response!">
<HoverableIcon
icon={<DislikeFeedbackIcon />}
icon={<DislikeFeedback size={16} />}
onClick={() => handleFeedback("dislike")}
/>
</CustomTooltip>
{regenerate && (
<RegenerateOption
onHoverChange={setIsRegenerateHovered}
selectedAssistant={currentPersona!}
regenerate={regenerate}
alternateModel={alternateModel}
/>
)}
</TooltipGroup>
</div>
) : (
@@ -547,31 +586,63 @@ export const AIMessage = ({
ref={hoverElementRef}
className={`
absolute -bottom-4
invisible ${(isHovering || settings?.isMobile) && "!visible"}
opacity-0 ${(isHovering || settings?.isMobile) && "!opacity-100"}
invisible ${(isHovering || isRegenerateHovered || settings?.isMobile) && "!visible"}
opacity-0 ${(isHovering || isRegenerateHovered || settings?.isMobile) && "!opacity-100"}
translate-y-2 ${(isHovering || settings?.isMobile) && "!translate-y-0"}
transition-transform duration-300 ease-in-out
flex md:flex-row gap-x-0.5 bg-background-125/40 p-1.5 rounded-lg
`}
>
<TooltipGroup>
<div className="flex justify-start w-full gap-x-0.5">
{includeMessageSwitcher && (
<div className="-mx-1 mr-auto">
<MessageSwitcher
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() => {
onMessageSelection(
otherMessagesCanSwitchTo[
currentMessageInd - 1
]
);
}}
handleNext={() => {
onMessageSelection(
otherMessagesCanSwitchTo[
currentMessageInd + 1
]
);
}}
/>
</div>
)}
</div>
<CustomTooltip showTick line content="Copy!">
<CopyButton content={content.toString()} />
</CustomTooltip>
<CustomTooltip showTick line content="Good response!">
<HoverableIcon
icon={<LikeFeedbackIcon />}
icon={<LikeFeedback />}
onClick={() => handleFeedback("like")}
/>
</CustomTooltip>
<CustomTooltip showTick line content="Bad response!">
<HoverableIcon
icon={<DislikeFeedbackIcon />}
icon={<DislikeFeedback size={16} />}
onClick={() => handleFeedback("dislike")}
/>
</CustomTooltip>
{regenerate && (
<RegenerateOption
selectedAssistant={currentPersona!}
regenerate={regenerate}
alternateModel={alternateModel}
onHoverChange={setIsRegenerateHovered}
/>
)}
</TooltipGroup>
</div>
))}
@@ -623,6 +694,7 @@ export const HumanMessage = ({
onEdit,
onMessageSelection,
shared,
stopGenerating = () => null,
}: {
shared?: boolean;
content: string;
@@ -631,6 +703,7 @@ export const HumanMessage = ({
otherMessagesCanSwitchTo?: number[];
onEdit?: (editedContent: string) => void;
onMessageSelection?: (messageId: number) => void;
stopGenerating?: () => void;
}) => {
const textareaRef = useRef<HTMLTextAreaElement>(null);
@@ -677,7 +750,6 @@ export const HumanMessage = ({
<div className="xl:ml-8">
<div className="flex flex-col mr-4">
<FileDisplay alignBubble files={files || []} />
<div className="flex justify-end">
<div className="w-full ml-8 flex w-full max-w-message-max break-words">
{isEditing ? (
@@ -700,24 +772,23 @@ export const HumanMessage = ({
<textarea
ref={textareaRef}
className={`
m-0
w-full
h-auto
shrink
border-0
rounded-lg
overflow-y-hidden
bg-background-emphasis
whitespace-normal
break-word
overscroll-contain
outline-none
placeholder-gray-400
resize-none
pl-4
overflow-y-auto
pr-12
py-4`}
m-0
w-full
h-auto
shrink
border-0
rounded-lg
bg-background-emphasis
whitespace-normal
break-word
overscroll-contain
outline-none
placeholder-gray-400
resize-none
pl-4
overflow-y-auto
pr-12
py-4`}
aria-multiline
role="textarea"
value={editedContent}
@@ -857,16 +928,18 @@ export const HumanMessage = ({
<MessageSwitcher
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() =>
handlePrevious={() => {
stopGenerating();
onMessageSelection(
otherMessagesCanSwitchTo[currentMessageInd - 1]
)
}
handleNext={() =>
);
}}
handleNext={() => {
stopGenerating();
onMessageSelection(
otherMessagesCanSwitchTo[currentMessageInd + 1]
)
}
);
}}
/>
</div>
)}

View File

@@ -4,9 +4,21 @@ import {
} from "@/components/BasicClickable";
import { HoverPopup } from "@/components/HoverPopup";
import { Hoverable } from "@/components/Hoverable";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { SourceIcon } from "@/components/SourceIcon";
import { ChevronDownIcon, InfoIcon } from "@/components/icons/icons";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
import { Citation } from "@/components/search/results/Citation";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useEffect, useRef, useState } from "react";
import {
DanswerDocument,
FilteredDanswerDocument,
} from "@/lib/search/interfaces";
import { ValidSources } from "@/lib/types";
import { useContext, useEffect, useRef, useState } from "react";
import { FiCheck, FiEdit2, FiSearch, FiX } from "react-icons/fi";
import { DownChevron } from "react-select/dist/declarations/src/components/indicators";
export function ShowHideDocsButton({
messageId,
@@ -37,26 +49,35 @@ export function ShowHideDocsButton({
export function SearchSummary({
query,
hasDocs,
filteredDocs,
finished,
messageId,
isCurrentlyShowingRetrieved,
handleShowRetrieved,
docs,
toggleDocumentSelection,
handleSearchQueryEdit,
}: {
toggleDocumentSelection?: () => void;
docs?: DanswerDocument[] | null;
filteredDocs: FilteredDanswerDocument[];
finished: boolean;
query: string;
hasDocs: boolean;
messageId: number | null;
isCurrentlyShowingRetrieved: boolean;
handleShowRetrieved: (messageId: number | null) => void;
handleSearchQueryEdit?: (query: string) => void;
}) {
const [isEditing, setIsEditing] = useState(false);
const [finalQuery, setFinalQuery] = useState(query);
const [isOverflowed, setIsOverflowed] = useState(false);
const searchingForRef = useRef<HTMLDivElement>(null);
const editQueryRef = useRef<HTMLInputElement>(null);
const [isDropdownOpen, setIsDropdownOpen] = useState(false);
const searchingForRef = useRef<HTMLDivElement | null>(null);
const editQueryRef = useRef<HTMLInputElement | null>(null);
const settings = useContext(SettingsContext);
const uniqueSourceTypes = Array.from(
new Set((docs || []).map((doc) => doc.source_type))
).slice(0, 3);
const toggleDropdown = () => {
setIsDropdownOpen(!isDropdownOpen);
};
useEffect(() => {
const checkOverflow = () => {
@@ -70,7 +91,7 @@ export function SearchSummary({
};
checkOverflow();
window.addEventListener("resize", checkOverflow); // Recheck on window resize
window.addEventListener("resize", checkOverflow);
return () => window.removeEventListener("resize", checkOverflow);
}, []);
@@ -88,15 +109,30 @@ export function SearchSummary({
}, [query]);
const searchingForDisplay = (
<div className={`flex p-1 rounded ${isOverflowed && "cursor-default"}`}>
<FiSearch className="flex-none mr-2 my-auto" size={14} />
<div
className={`${!finished && "loading-text"}
!text-sm !line-clamp-1 !break-all px-0.5`}
ref={searchingForRef}
>
{finished ? "Searched" : "Searching"} for: <i> {finalQuery}</i>
</div>
<div
className={`flex my-auto items-center ${isOverflowed && "cursor-default"}`}
>
{finished ? (
<>
<div
onClick={() => {
toggleDropdown();
}}
className={` transition-colors duration-300 group-hover:text-text-toolhover cursor-pointer text-text-toolrun !line-clamp-1 !break-all pr-0.5`}
ref={searchingForRef}
>
Searched {filteredDocs.length > 0 && filteredDocs.length} document
{filteredDocs.length != 1 && "s"} for {query}
</div>
</>
) : (
<div
className={`loading-text !text-sm !line-clamp-1 !break-all px-0.5`}
ref={searchingForRef}
>
Searching for: <i> {finalQuery}</i>
</div>
)}
</div>
);
@@ -147,43 +183,127 @@ export function SearchSummary({
</div>
</div>
) : null;
const SearchBlock = ({ doc, ind }: { doc: DanswerDocument; ind: number }) => {
return (
<div
onClick={() => {
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={doc.document_id}
className={`flex items-start gap-3 px-4 py-3 text-token-text-secondary ${ind == 0 && "rounded-t-xl"} hover:bg-background-100 group relative text-sm`}
// className="w-full text-sm flex transition-all duration-500 hover:bg-background-125 bg-text-100 py-4 border-b"
>
<div className="mt-1 scale-[.9] flex-none">
{doc.is_internet ? (
<InternetSearchIcon url={doc.link} />
) : (
<SourceIcon sourceType={doc.source_type} iconSize={18} />
)}
</div>
<div className="flex flex-col">
<a
href={doc.link}
target="_blank"
className="line-clamp-1 text-text-900"
>
{/* <Citation link={doc.link} index={ind + 1} /> */}
<p className="shrink truncate ellipsis break-all ">
{doc.semantic_identifier || doc.document_id}
</p>
<p className="line-clamp-3 text-text-500 break-words">
{doc.blurb}
</p>
</a>
{/* <div className="flex overscroll-x-scroll mt-.5">
<DocumentMetadataBlock document={doc} />
</div> */}
</div>
</div>
);
};
return (
<div className="flex">
{isEditing ? (
editInput
) : (
<>
<div className="text-sm">
{isOverflowed ? (
<HoverPopup
mainContent={searchingForDisplay}
popupContent={
<div>
<b>Full query:</b>{" "}
<div className="mt-1 italic">{query}</div>
</div>
}
direction="top"
/>
) : (
searchingForDisplay
<>
<div className="flex gap-x-2 group">
{isEditing ? (
editInput
) : (
<>
<div className="my-auto text-sm">
{isOverflowed ? (
<HoverPopup
mainContent={searchingForDisplay}
popupContent={
<div>
<b>Full query:</b>{" "}
<div className="mt-1 italic">{query}</div>
</div>
}
direction="top"
/>
) : (
searchingForDisplay
)}
</div>
{handleSearchQueryEdit && (
<Tooltip delayDuration={1000} content={"Edit Search"}>
<button
className="my-auto cursor-pointer rounded"
onClick={() => {
setIsEditing(true);
}}
>
<FiEdit2 />
</button>
</Tooltip>
)}
</div>
{handleSearchQueryEdit && (
<Tooltip delayDuration={1000} content={"Edit Search"}>
<button
className="my-auto hover:bg-hover p-1.5 rounded"
<button
className="my-auto invisible group-hover:visible transition-all duration-300 hover:bg-hover rounded"
onClick={toggleDropdown}
>
<ChevronDownIcon
className={`transform transition-transform ${isDropdownOpen ? "rotate-180" : ""}`}
/>
</button>
</>
)}
</div>
{isDropdownOpen && docs && docs.length > 0 && (
<div
className={`mt-2 -mx-8 w-full mb-4 flex relative transition-all duration-300 ${isDropdownOpen ? "opacity-100 max-h-[1000px]" : "opacity-0 max-h-0"}`}
>
<div className="w-full">
<div className="mx-8 flex rounded overflow-hidden rounded-lg border-1.5 border divide-y divider-y-1.5 divider-y-border border-border flex-col gap-x-4">
{!settings?.isMobile &&
filteredDocs.length > 0 &&
filteredDocs.map((doc, ind) => (
<SearchBlock key={ind} doc={doc} ind={ind} />
))}
<div
onClick={() => {
setIsEditing(true);
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={-1}
className="cursor-pointer w-full flex transition-all duration-500 hover:bg-background-100 py-3 border-b"
>
<FiEdit2 />
</button>
</Tooltip>
)}
</>
<div key={0} className="px-3 invisible scale-[.9] flex-none">
<SourceIcon sourceType={"file"} iconSize={18} />
</div>
<div className="text-sm flex justify-between text-text-900">
<p className="line-clamp-1">See context</p>
<div className="flex gap-x-1"></div>
</div>
</div>
</div>
</div>
</div>
)}
</div>
</>
);
}

View File

@@ -27,15 +27,12 @@ export function SkippedSearch({
handleForceSearch: () => void;
}) {
return (
<div className="flex text-sm !pt-0 p-1">
<div className="flex group-hover:text-text-900 text-text-500 text-sm !pt-0 p-1">
<div className="flex mb-auto">
<FiBook className="my-auto flex-none mr-2" size={14} />
<div className="my-auto cursor-default">
<span className="mobile:hidden">
The AI decided this query didn&apos;t need a search
</span>
<span className="desktop:hidden">No search</span>
</div>
<span className="mobile:hidden">
The AI decided this query didn&apos;t need a search
</span>
<span className="desktop:hidden">No search</span>
</div>
<div className="ml-auto my-auto" onClick={handleForceSearch}>

View File

@@ -26,6 +26,7 @@ import { CHAT_SESSION_ID_KEY, FOLDER_ID_KEY } from "@/lib/drag/constants";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { WarningCircle } from "@phosphor-icons/react";
import { CustomTooltip } from "@/components/tooltip/CustomTooltip";
import { NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH } from "@/lib/constants";
export function ChatSessionDisplay({
chatSession,
@@ -33,6 +34,7 @@ export function ChatSessionDisplay({
isSelected,
skipGradient,
closeSidebar,
stopGenerating = () => null,
showShareModal,
showDeleteModal,
}: {
@@ -43,6 +45,7 @@ export function ChatSessionDisplay({
// if not set, the gradient will still be applied and cause weirdness
skipGradient?: boolean;
closeSidebar?: () => void;
stopGenerating?: () => void;
showShareModal?: (chatSession: ChatSession) => void;
showDeleteModal?: (chatSession: ChatSession) => void;
}) {
@@ -99,6 +102,9 @@ export function ChatSessionDisplay({
className="flex my-1 group relative"
key={chatSession.id}
onClick={() => {
if (NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH) {
stopGenerating();
}
if (settings?.isMobile && closeSidebar) {
closeSidebar();
}

View File

@@ -40,6 +40,7 @@ interface HistorySidebarProps {
reset?: () => void;
showShareModal?: (chatSession: ChatSession) => void;
showDeleteModal?: (chatSession: ChatSession) => void;
stopGenerating?: () => void;
}
export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
@@ -54,6 +55,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
openedFolders,
toggleSidebar,
removeToggle,
stopGenerating = () => null,
showShareModal,
showDeleteModal,
},
@@ -179,6 +181,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
)}
<div className="border-b border-border pb-4 mx-3" />
<PagesTab
stopGenerating={stopGenerating}
newFolderId={newFolderId}
showDeleteModal={showDeleteModal}
showShareModal={showShareModal}

View File

@@ -17,10 +17,12 @@ export function PagesTab({
folders,
openedFolders,
closeSidebar,
stopGenerating,
newFolderId,
showShareModal,
showDeleteModal,
}: {
stopGenerating: () => void;
page: pageType;
existingChats?: ChatSession[];
currentChatId?: number;
@@ -125,6 +127,7 @@ export function PagesTab({
return (
<div key={`${chat.id}-${chat.name}`}>
<ChatSessionDisplay
stopGenerating={stopGenerating}
showDeleteModal={showDeleteModal}
showShareModal={showShareModal}
closeSidebar={closeSidebar}

View File

@@ -15,6 +15,7 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import { useContext, useEffect, useState } from "react";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { DanswerInitializingLoader } from "@/components/DanswerInitializingLoader";
import { usePopup } from "@/components/admin/connectors/Popup";
function BackToDanswerButton() {
const router = useRouter();
@@ -44,6 +45,7 @@ export function SharedChatDisplay({
Prism.highlightAll();
setIsReady(true);
}, []);
const { popup, setPopup } = usePopup();
if (!chatSession) {
return (
<div className="min-h-full w-full">
@@ -66,6 +68,7 @@ export function SharedChatDisplay({
return (
<div className="w-full h-[100dvh] overflow-hidden">
{popup}
<div className="flex max-h-full overflow-hidden pb-[72px]">
<div className="flex w-full overflow-hidden overflow-y-scroll">
<div className="w-full h-full flex-col flex max-w-message-max mx-auto">
@@ -95,6 +98,7 @@ export function SharedChatDisplay({
} else {
return (
<AIMessage
setPopup={setPopup}
shared
currentPersona={currentPersona!}
key={message.messageId}

View File

@@ -0,0 +1,97 @@
import React, { useState, useEffect, useRef } from "react";
export default function GeneratingImage({ isCompleted = false }) {
const [progress, setProgress] = useState(0);
const progressRef = useRef(0);
const animationRef = useRef<number>();
const startTimeRef = useRef<number>(Date.now());
useEffect(() => {
let lastUpdateTime = 0;
const updateInterval = 500; // Update at most every 500ms
const animationDuration = 30000; // Animation will take 30 seconds to reach ~99%
const animate = (timestamp: number) => {
const elapsedTime = timestamp - startTimeRef.current;
// Slower logarithmic curve
const maxProgress = 99.9;
const progress =
maxProgress * (1 - Math.exp(-elapsedTime / animationDuration));
// Only update if enough time has passed since the last update
if (timestamp - lastUpdateTime > updateInterval) {
progressRef.current = progress;
setProgress(Math.round(progress * 10) / 10); // Round to 1 decimal place
lastUpdateTime = timestamp;
}
if (!isCompleted && elapsedTime < animationDuration) {
animationRef.current = requestAnimationFrame(animate);
}
};
startTimeRef.current = performance.now();
animationRef.current = requestAnimationFrame(animate);
return () => {
if (animationRef.current) {
cancelAnimationFrame(animationRef.current);
}
};
}, [isCompleted]);
useEffect(() => {
if (isCompleted) {
if (animationRef.current) {
cancelAnimationFrame(animationRef.current);
}
setProgress(100);
}
}, [isCompleted]);
return (
<div className="object-cover object-center border border-neutral-200 bg-neutral-100 items-center justify-center overflow-hidden flex rounded-lg w-96 h-96 transition-opacity duration-300 opacity-100">
<div className="m-auto relative flex">
<svg className="w-16 h-16 transform -rotate-90" viewBox="0 0 100 100">
<circle
className="text-gray-200"
strokeWidth="8"
stroke="currentColor"
fill="transparent"
r="44"
cx="50"
cy="50"
/>
<circle
className="text-gray-800 transition-all duration-300"
strokeWidth="8"
strokeDasharray={276.46}
strokeDashoffset={276.46 * (1 - progress / 100)}
strokeLinecap="round"
stroke="currentColor"
fill="transparent"
r="44"
cx="50"
cy="50"
/>
</svg>
<div className="absolute inset-0 flex items-center justify-center">
<svg
className="w-6 h-6 text-neutral-500 animate-pulse-strong"
fill="none"
viewBox="0 0 24 24"
stroke="currentColor"
>
<path
strokeLinecap="round"
strokeLinejoin="round"
strokeWidth="2"
d="M4 16l4.586-4.586a2 2 0 012.828 0L16 16m-2-2l1.586-1.586a2 2 0 012.828 0L20 14m-6-6h.01M6 20h12a2 2 0 002-2V6a2 2 0 00-2-2H6a2 2 0 00-2 2v12a2 2 0 002 2z"
/>
</svg>
</div>
</div>
</div>
);
}

View File

@@ -0,0 +1,83 @@
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { CopyIcon } from "@/components/icons/icons";
import { Divider } from "@tremor/react";
import React, { forwardRef, useState } from "react";
import { FiCheck } from "react-icons/fi";
interface PromptDisplayProps {
prompt1: string;
prompt2?: string;
arg: string;
setPopup: (popupSpec: PopupSpec | null) => void;
}
const DualPromptDisplay = forwardRef<HTMLDivElement, PromptDisplayProps>(
({ prompt1, prompt2, setPopup, arg }, ref) => {
const [copied, setCopied] = useState<number | null>(null);
const copyToClipboard = (text: string, index: number) => {
navigator.clipboard
.writeText(text)
.then(() => {
setPopup({ message: "Copied to clipboard", type: "success" });
setCopied(index);
setTimeout(() => setCopied(null), 2000); // Reset copy status after 2 seconds
})
.catch((err) => {
setPopup({ message: "Failed to copy", type: "error" });
});
};
const PromptSection = ({
copied,
prompt,
index,
}: {
copied: number | null;
prompt: string;
index: number;
}) => (
<div className="w-full p-2 rounded-lg">
<h2 className="text-lg font-semibold mb-2">
{arg} {index + 1}
</h2>
<p className="line-clamp-6 text-sm text-gray-800">{prompt}</p>
<button
onMouseDown={() => copyToClipboard(prompt, index)}
className="flex mt-2 text-sm cursor-pointer items-center justify-center py-2 px-3 border border-background-200 bg-inverted text-text-900 rounded-full hover:bg-background-100 transition duration-200"
>
{copied != null ? (
<>
<FiCheck className="mr-2" size={16} />
Copied!
</>
) : (
<>
<CopyIcon className="mr-2" size={16} />
Copy
</>
)}
</button>
</div>
);
return (
<div className="w-[400px] bg-inverted mx-auto p-6 rounded-lg shadow-lg">
<div className="flex flex-col gap-x-4">
<PromptSection copied={copied} prompt={prompt1} index={0} />
{prompt2 && (
<>
<Divider />
<PromptSection copied={copied} prompt={prompt2} index={1} />
</>
)}
</div>
</div>
);
}
);
DualPromptDisplay.displayName = "DualPromptDisplay";
export default DualPromptDisplay;

View File

@@ -1 +1,6 @@
export type FeedbackType = "like" | "dislike";
export type ChatState = "input" | "loading" | "streaming" | "toolBuilding";
export interface RegenerationState {
regenerating: boolean;
finalMessageIndex: number;
}

View File

@@ -119,6 +119,28 @@
}
}
@keyframes custom-spin {
0% {
transform: rotate(0deg);
}
25% {
transform: rotate(180deg);
}
50% {
transform: rotate(270deg);
}
75% {
transform: rotate(315deg);
}
100% {
transform: rotate(360deg);
}
}
.generating-spin {
animation: custom-spin 2s cubic-bezier(0.4, 0.2, 0.6, 0.8) infinite;
}
.collapsible {
max-height: 300px;
transition:
@@ -223,3 +245,17 @@ code[class*="language-"] {
.code-line .token.attr-name {
color: theme("colors.token-attr-name");
}
@keyframes pulse-strong {
0%,
100% {
opacity: 0.9;
}
50% {
opacity: 0.3;
}
}
.animate-pulse-strong {
animation: pulse-strong 1.5s cubic-bezier(0.4, 0, 0.6, 1) infinite;
}

View File

@@ -320,15 +320,15 @@ export const DefaultDropdown = forwardRef<HTMLDivElement, DefaultDropdownProps>(
const Content = (
<div
className={`
flex
text-sm
bg-background
px-3
py-1.5
rounded-lg
border
border-border
cursor-pointer`}
flex
text-sm
bg-background
px-3
py-1.5
rounded-lg
border
border-border
cursor-pointer`}
>
<p className="line-clamp-1">
{selectedOption?.name ||

View File

@@ -1,4 +1,3 @@
import { IconProps } from "@tremor/react";
import { IconType } from "react-icons";
const ICON_SIZE = 15;
@@ -7,13 +6,22 @@ export const Hoverable: React.FC<{
icon: IconType;
onClick?: () => void;
size?: number;
}> = ({ icon, onClick, size = ICON_SIZE }) => {
active?: boolean;
hoverText?: string;
}> = ({ icon: Icon, active, hoverText, onClick, size = ICON_SIZE }) => {
return (
<div
className="hover:bg-hover p-1.5 rounded h-fit cursor-pointer"
className={`group relative flex items-center overflow-hidden p-1.5 h-fit rounded-md cursor-pointer transition-all duration-300 ease-in-out hover:bg-hover`}
onClick={onClick}
>
{icon({ size: size, className: "my-auto" })}
<div className="flex items-center ">
<Icon size={size} className="text-gray-600 shrink-0" />
{hoverText && (
<div className="max-w-0 leading-none whitespace-nowrap overflow-hidden transition-all duration-300 ease-in-out group-hover:max-w-xs group-hover:ml-2">
<span className="text-xs text-gray-700">{hoverText}</span>
</div>
)}
</div>
</div>
);
};

View File

@@ -9,6 +9,7 @@ import { useSortable } from "@dnd-kit/sortable";
import React, { useState } from "react";
import { FiBookmark } from "react-icons/fi";
import { MdDragIndicator } from "react-icons/md";
import { DragHandle } from "../table/DragHandle";
export const AssistantCard = ({
assistant,
@@ -107,10 +108,7 @@ export function DraggableAssistantCard(props: {
style={style}
className="overlow-y-scroll inputscroll flex items-center"
>
<div {...attributes} {...listeners} className="mr-1 cursor-grab">
<MdDragIndicator className="h-3 w-3 flex-none" />
</div>
<DragHandle {...attributes} {...listeners} />
<AssistantCard {...props} />
</div>
);

View File

@@ -710,6 +710,85 @@ export const ChevronIcon = ({
);
};
export const StarFeedback = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 24 24"
>
<path
fill="none"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="1.5"
d="m12.495 18.587l4.092 2.15a1.044 1.044 0 0 0 1.514-1.106l-.783-4.552a1.045 1.045 0 0 1 .303-.929l3.31-3.226a1.043 1.043 0 0 0-.575-1.785l-4.572-.657A1.044 1.044 0 0 1 15 7.907l-2.088-4.175a1.044 1.044 0 0 0-1.88 0L8.947 7.907a1.044 1.044 0 0 1-.783.575l-4.51.657a1.044 1.044 0 0 0-.584 1.785l3.309 3.226a1.044 1.044 0 0 1 .303.93l-.783 4.55a1.044 1.044 0 0 0 1.513 1.107l4.093-2.15a1.043 1.043 0 0 1 .991 0"
/>
</svg>
);
};
export const DislikeFeedback = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 24 24"
>
<g
fill="none"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="1.5"
>
<path d="M5.75 2.75H4.568c-.98 0-1.775.795-1.775 1.776v8.284c0 .98.795 1.775 1.775 1.775h1.184c.98 0 1.775-.794 1.775-1.775V4.526c0-.98-.795-1.776-1.775-1.776" />
<path d="m21.16 11.757l-1.42-7.101a2.368 2.368 0 0 0-2.367-1.906h-7.48a2.367 2.367 0 0 0-2.367 2.367v7.101a3.231 3.231 0 0 0 1.184 2.367l.982 5.918a.887.887 0 0 0 1.278.65l1.1-.543a3.551 3.551 0 0 0 1.87-4.048l-.496-1.965h5.396a2.368 2.368 0 0 0 2.32-2.84" />
</g>
</svg>
);
};
export const LikeFeedback = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 24 24"
>
<g
fill="none"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="1.5"
>
<path d="M5.75 9.415H4.568c-.98 0-1.775.794-1.775 1.775v8.284c0 .98.795 1.776 1.775 1.776h1.184c.98 0 1.775-.795 1.775-1.776V11.19c0-.98-.795-1.775-1.775-1.775" />
<path d="m21.16 12.243l-1.42 7.101a2.367 2.367 0 0 1-2.367 1.906h-7.48a2.367 2.367 0 0 1-2.367-2.367v-7.101A3.231 3.231 0 0 1 8.71 9.415l.982-5.918a.888.888 0 0 1 1.278-.65l1.1.544a3.55 3.55 0 0 1 1.87 4.047l-.496 1.965h5.396a2.367 2.367 0 0 1 2.32 2.84" />
</g>
</svg>
);
};
export const CheckmarkIcon = ({
size = 16,
className = defaultTailwindCSS,
@@ -1718,6 +1797,29 @@ export const FilledLikeIcon = ({
);
};
export const StopGeneratingIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
width="200"
height="200"
viewBox="0 0 14 14"
>
<path
fill="currentColor"
fill-rule="evenodd"
d="M1.5 0A1.5 1.5 0 0 0 0 1.5v11A1.5 1.5 0 0 0 1.5 14h11a1.5 1.5 0 0 0 1.5-1.5v-11A1.5 1.5 0 0 0 12.5 0z"
clip-rule="evenodd"
/>
</svg>
);
};
export const LikeFeedbackIcon = ({
size = 16,
className = defaultTailwindCSS,
@@ -2581,3 +2683,40 @@ export const MinusIcon = ({
</svg>
);
};
export const ToolCallIcon = ({
size = 16,
className = defaultTailwindCSS,
}: IconProps) => {
return (
<svg
style={{ width: `${size}px`, height: `${size}px` }}
className={`w-[${size}px] h-[${size}px] ` + className}
xmlns="http://www.w3.org/2000/svg"
viewBox="0 0 19 15"
fill="none"
>
<path
d="M4.42 0.75H2.8625H2.75C1.64543 0.75 0.75 1.64543 0.75 2.75V11.65C0.75 12.7546 1.64543 13.65 2.75 13.65H2.8625C2.8625 13.65 2.8625 13.65 2.8625 13.65C2.8625 13.65 4.00751 13.65 4.42 13.65M13.98 13.65H15.5375H15.65C16.7546 13.65 17.65 12.7546 17.65 11.65V2.75C17.65 1.64543 16.7546 0.75 15.65 0.75H15.5375H13.98"
stroke="currentColor"
strokeWidth="1.5"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M5.55283 4.21963C5.25993 3.92674 4.78506 3.92674 4.49217 4.21963C4.19927 4.51252 4.19927 4.9874 4.49217 5.28029L6.36184 7.14996L4.49217 9.01963C4.19927 9.31252 4.19927 9.7874 4.49217 10.0803C4.78506 10.3732 5.25993 10.3732 5.55283 10.0803L7.95283 7.68029C8.24572 7.3874 8.24572 6.91252 7.95283 6.61963L5.55283 4.21963Z"
fill="currentColor"
stroke="currentColor"
strokeWidth="0.2"
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M9.77753 8.75003C9.3357 8.75003 8.97753 9.10821 8.97753 9.55003C8.97753 9.99186 9.3357 10.35 9.77753 10.35H13.2775C13.7194 10.35 14.0775 9.99186 14.0775 9.55003C14.0775 9.10821 13.7194 8.75003 13.2775 8.75003H9.77753Z"
fill="currentColor"
stroke="currentColor"
strokeWidth="0.1"
/>
</svg>
);
};

View File

@@ -23,31 +23,17 @@ export function Citation({
>
<a
onMouseDown={() => (link ? window.open(link, "_blank") : undefined)}
className="cursor-pointer inline ml-1 align-middle"
className="cursor-pointer inline ml-1 font-sans align-middle inline-block text-sm text-blue-500 cursor-help leading-none inline ml-1 align-middle"
>
<span className="group relative -top-1 text-sm text-gray-500 dark:text-gray-400 selection:bg-indigo-300 selection:text-black dark:selection:bg-indigo-900 dark:selection:text-white">
<span
className="inline-flex bg-background-200 group-hover:bg-background-300 items-center justify-center h-3.5 min-w-3.5 px-1 text-center text-xs rounded-full border-1 border-gray-400 ring-1 ring-gray-400 divide-gray-300 dark:divide-gray-700 dark:ring-gray-700 dark:border-gray-700 transition duration-150"
data-number="3"
>
{innerText}
</span>
</span>
[{innerText}]
</a>
</CustomTooltip>
);
} else {
return (
<CustomTooltip content={<div>This doc doesn&apos;t have a link!</div>}>
<div className="inline-block cursor-help leading-none inline ml-1 align-middle">
<span className="group relative -top-1 text-gray-500 dark:text-gray-400 selection:bg-indigo-300 selection:text-black dark:selection:bg-indigo-900 dark:selection:text-white">
<span
className="inline-flex bg-background-200 group-hover:bg-background-300 items-center justify-center h-3.5 min-w-3.5 flex-none px-1 text-center text-xs rounded-full border-1 border-gray-400 ring-1 ring-gray-400 divide-gray-300 dark:divide-gray-700 dark:ring-gray-700 dark:border-gray-700 transition duration-150"
data-number="3"
>
{innerText}
</span>
</span>
<div className="inline-block text-sm font-sans text-blue-500 cursor-help leading-none inline ml-1 align-middle">
[{innerText}]
</div>
</CustomTooltip>
);

View File

@@ -4,9 +4,9 @@ import { MdDragIndicator } from "react-icons/md";
export const DragHandle = (props: any) => {
return (
<div
className={
props.isDragging ? "hover:cursor-grabbing" : "hover:cursor-grab"
}
className={`mobile:hidden
${props.isDragging ? "hover:cursor-grabbing" : "hover:cursor-grab"}
`}
{...props}
>
<MdDragIndicator />

View File

@@ -44,8 +44,10 @@ export const CustomTooltip = ({
wrap,
showTick = false,
delay = 500,
maxWidth,
position = "bottom",
}: {
maxWidth?: boolean;
content: string | ReactNode;
children: JSX.Element;
large?: boolean;
@@ -56,6 +58,7 @@ export const CustomTooltip = ({
wrap?: boolean;
citation?: boolean;
position?: "top" | "bottom";
className?: string;
}) => {
const [isVisible, setIsVisible] = useState(false);
const [tooltipPosition, setTooltipPosition] = useState({ top: 0, left: 0 });

View File

@@ -0,0 +1,36 @@
// For handling AI message `sequences` (ie. ai messages which are streamed in sequence as separate messags but are in reality one message)
import { Message } from "@/app/chat/interfaces";
import { DanswerDocument } from "../search/interfaces";
export function getConsecutiveAIMessagesAtEnd(
messageHistory: Message[]
): Message[] {
const aiMessages = [];
for (let i = messageHistory.length - 1; i >= 0; i--) {
if (messageHistory[i]?.type === "assistant") {
aiMessages.unshift(messageHistory[i]);
} else {
break;
}
}
return aiMessages;
}
export function getUniqueDocumentsFromAIMessages(
messages: Message[]
): DanswerDocument[] {
const uniqueDocumentsMap = new Map<string, DanswerDocument>();
messages.forEach((message) => {
if (message.documents) {
message.documents.forEach((doc) => {
const uniqueKey = `${doc.document_id}-${doc.chunk_ind}`;
if (!uniqueDocumentsMap.has(uniqueKey)) {
uniqueDocumentsMap.set(uniqueKey, doc);
}
});
}
});
return Array.from(uniqueDocumentsMap.values());
}

View File

@@ -4,6 +4,11 @@ export const HOST_URL = process.env.WEB_DOMAIN || "http://127.0.0.1:3000";
export const HEADER_HEIGHT = "h-16";
export const SUB_HEADER = "h-12";
export const SIDEBAR_WIDTH_CONST = 425;
export const MOBILE_SIDEBAR_WIDTH_CONST = 300;
export const SIDEBAR_CARD_WIDTH = `w-[400px]`;
export const MOBILE_SIDEBAR_CARD_WIDTH = `w-[275px]`;
export const INTERNAL_URL = process.env.INTERNAL_URL || "http://127.0.0.1:8080";
export const NEXT_PUBLIC_DISABLE_STREAMING =
process.env.NEXT_PUBLIC_DISABLE_STREAMING?.toLowerCase() === "true";
@@ -24,9 +29,6 @@ export const GOOGLE_DRIVE_AUTH_IS_ADMIN_COOKIE_NAME =
export const SEARCH_TYPE_COOKIE_NAME = "search_type";
export const AGENTIC_SEARCH_TYPE_COOKIE_NAME = "agentic_type";
export const SIDEBAR_WIDTH_CONST = "350px";
export const SIDEBAR_WIDTH = `w-[350px]`;
export const LOGOUT_DISABLED =
process.env.NEXT_PUBLIC_DISABLE_LOGOUT?.toLowerCase() === "true";
@@ -35,6 +37,9 @@ export const NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN =
export const TOGGLED_CONNECTORS_COOKIE_NAME = "toggled_connectors";
export const NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH =
process.env.NEXT_PUBLIC_STOP_GENERATING_ON_SWITCH?.toLowerCase() === "true";
/* Enterprise-only settings */
// NOTE: this should ONLY be used on the server-side. If used client side,

View File

@@ -18,6 +18,9 @@ export type SearchType = (typeof SearchType)[keyof typeof SearchType];
export interface AnswerPiecePacket {
answer_piece: string;
}
export interface DelimiterPiece {
delimiter: boolean;
}
export interface ErrorMessagePacket {
error: string;

View File

@@ -5,6 +5,7 @@ import {
import {
AnswerPiecePacket,
DanswerDocument,
DelimiterPiece,
DocumentInfoPacket,
ErrorMessagePacket,
Quote,
@@ -91,6 +92,7 @@ export const searchRequestStreamed = async ({
| DocumentInfoPacket
| LLMRelevanceFilterPacket
| BackendMessage
| DelimiterPiece
| RelevanceChunk
>(decoder.decode(value, { stream: true }), previousPartialChunk);
if (!completedChunks.length && !partialChunk) {

View File

@@ -1,3 +1,5 @@
import { PacketType } from "@/app/chat/lib";
type NonEmptyObject = { [k: string]: any };
const processSingleChunk = <T extends NonEmptyObject>(
@@ -75,3 +77,33 @@ export async function* handleStream<T extends NonEmptyObject>(
yield await Promise.resolve(completedChunks);
}
}
export async function* handleSSEStream<T extends PacketType>(
streamingResponse: Response
): AsyncGenerator<T, void, unknown> {
const reader = streamingResponse.body?.getReader();
const decoder = new TextDecoder();
while (true) {
const rawChunk = await reader?.read();
if (!rawChunk) {
throw new Error("Unable to process chunk");
}
const { done, value } = rawChunk;
if (done) {
break;
}
const chunk = decoder.decode(value);
const lines = chunk.split("\n").filter((line) => line.trim() !== "");
for (const line of lines) {
try {
const data = JSON.parse(line) as T;
yield data;
} catch (error) {
console.error("Error parsing SSE data:", error);
}
}
}
}

View File

@@ -90,10 +90,11 @@ module.exports = {
"background-200": "#e5e5e5", // neutral-200
"background-300": "#d4d4d4", // neutral-300
"background-400": "#a3a3a3", // neutral-400
"background-500": "#737373", // neutral-400
"background-600": "#525252", // neutral-400
"background-700": "#404040", // neutral-400
"background-800": "#262626", // neutral-800
"background-400": "#a3a3a3", // neutral-500
"background-500": "#737373", // darkMedium, neutral-500
"background-600": "#525252", // dark, neutral-600
"background-700": "#404040", // solid, neutral-700
"background-800": "#262626", // solidDark, neutral-800
"background-900": "#111827", // gray-900
"background-inverted": "#000000", // black
@@ -110,6 +111,10 @@ module.exports = {
"text-600": "#525252", // dark, neutral-600
"text-700": "#404040", // solid, neutral-700
"text-800": "#262626", // solidDark, neutral-800
"text-900": "#171717", // dark, neutral-900
"text-toolrun": "#a3a3a3", // darkMedium, neutral-500
"text-toolhover": "#171717", // dark, neutral-900
subtle: "#6b7280", // gray-500
default: "#4b5563", // gray-600