Compare commits

...

36 Commits

Author SHA1 Message Date
pablodanswer
a0d3eb28e8 remove csv + add height 2024-09-18 21:04:39 -07:00
pablodanswer
d789c9ac52 update logs 2024-09-17 15:16:41 -07:00
pablodanswer
d989ce13e7 overhaul 2024-09-17 15:08:43 -07:00
pablodanswer
3170430673 updated tool runner + displays 2024-09-17 08:55:14 -07:00
pablodanswer
2beffdaa6e update/finalize 2024-09-16 20:12:51 -07:00
pablodanswer
77ee061e67 slightly cleaner loading animations 2024-09-16 18:46:16 -07:00
pablodanswer
532bc53a9a updated tooltip spacing 2024-09-16 18:10:50 -07:00
pablodanswer
7b7b95703d add csv display 2024-09-16 18:08:23 -07:00
pablodanswer
fcc5efdaf8 stash 2024-09-16 16:44:51 -07:00
pablodanswer
1ea4a53af1 udpate 2024-09-16 13:53:21 -07:00
pablodanswer
47479c8799 update 2024-09-16 12:40:23 -07:00
pablodanswer
fbc5008259 logs 2024-09-16 12:40:23 -07:00
pablodanswer
d684fb116d update constants 2024-09-16 12:40:23 -07:00
pablodanswer
2e61b374f4 remove plots 2024-09-16 12:40:23 -07:00
pablodanswer
15d324834f update csv 2024-09-16 12:40:23 -07:00
pablodanswer
de9a9b7b6e add graphs again 2024-09-16 12:40:23 -07:00
pablodanswer
47eb8c521d temp` 2024-09-16 12:40:23 -07:00
pablodanswer
875fb05dca functional search and chat once again! 2024-09-16 12:40:23 -07:00
pablodanswer
1285b2f4d4 update for typing 2024-09-16 12:39:58 -07:00
pablodanswer
842628771b minor robustification for search 2024-09-16 12:07:52 -07:00
pablodanswer
7a9d5bd92e minor updates 2024-09-16 11:45:35 -07:00
pablodanswer
4f3b513ccb minor update 2024-09-16 11:44:39 -07:00
pablodanswer
cd454dd780 update clarity 2024-09-16 11:37:24 -07:00
pablodanswer
9140ee99cb asdf 2024-09-16 11:26:57 -07:00
pablodanswer
a64f27c895 functional 2024-09-16 11:26:57 -07:00
pablodanswer
fdf5611a35 add back frozen message map:wq 2024-09-16 11:26:57 -07:00
pablodanswer
c4f483d100 update port for integration testing 2024-09-16 11:26:57 -07:00
pablodanswer
fc28c6b9e1 fix stubborn typing issue 2024-09-16 11:26:57 -07:00
pablodanswer
33e25dbd8b clean up logs / build issues 2024-09-16 11:26:57 -07:00
pablodanswer
659e8cb69e validated + build-ready 2024-09-16 11:26:57 -07:00
pablodanswer
681175e9c3 add edits and so on 2024-09-16 11:26:57 -07:00
pablodanswer
de18ec7ea4 functional ux standing till 2024-09-16 11:26:57 -07:00
pablodanswer
9edbb0806d add back image citations 2024-09-16 11:26:57 -07:00
pablodanswer
63d10e7482 functional search and chat once again! 2024-09-16 11:26:57 -07:00
pablodanswer
ff6a15b5af squash 2024-09-16 11:26:57 -07:00
pablodanswer
49397e8a86 add sequential tool calls 2024-09-16 11:26:57 -07:00
77 changed files with 4338 additions and 718 deletions

BIN
backend/aaa garp.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 27 KiB

View File

@@ -0,0 +1,65 @@
"""single tool call per message
Revision ID: 4e8e7ae58189
Revises: 5c7fdadae813
Create Date: 2024-09-09 10:07:58.008838
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4e8e7ae58189"
down_revision = "5c7fdadae813"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Move the chat_message <-> tool_call link onto chat_message.

    Adds a nullable ``tool_call_id`` FK column to ``chat_message``,
    copies the existing ``tool_call.message_id`` links across, drops the
    old column, and enforces a one-to-one relationship via a unique
    constraint on the new column.
    """
    # Create the new column
    op.add_column(
        "chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
    )
    op.create_foreign_key(
        "fk_chat_message_tool_call",
        "chat_message",
        "tool_call",
        ["tool_call_id"],
        ["id"],
    )
    # Migrate existing data.
    # NOTE(review): LIMIT 1 picks an arbitrary tool call when a message
    # has several; the remaining tool_call rows are left orphaned (their
    # link is lost, they are not deleted) — confirm this is intended.
    op.execute(
        "UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
    )
    # Drop the old relationship (FK constraint first, then the column)
    op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
    op.drop_column("tool_call", "message_id")
    # Add a unique constraint to ensure one-to-one relationship
    op.create_unique_constraint(
        "uq_chat_message_tool_call_id", "chat_message", ["tool_call_id"]
    )
def downgrade() -> None:
    """Restore the original ``tool_call.message_id`` relationship.

    Re-adds the old column/FK on ``tool_call``, copies the links back
    from ``chat_message.tool_call_id``, then removes the new column.
    Tool calls orphaned by ``upgrade`` (no chat_message points at them)
    end up with a NULL ``message_id``.
    """
    # Add back the old column
    op.add_column(
        "tool_call",
        sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
    )
    op.create_foreign_key(
        "tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
    )
    # Migrate data back
    op.execute(
        "UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
    )
    # Drop the new column (FK constraint first, then the column; the
    # unique constraint on tool_call_id is dropped with the column)
    op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
    op.drop_column("chat_message", "tool_call_id")

View File

@@ -11,6 +11,7 @@ from danswer.search.enums import SearchType
from danswer.search.models import RetrievalDocs
from danswer.search.models import SearchResponse
from danswer.tools.custom.base_tool_types import ToolResultType
from danswer.tools.graphing.models import GraphGenerationDisplay
class LlmDoc(BaseModel):
@@ -48,6 +49,8 @@ class QADocsResponse(RetrievalDocs):
class StreamStopReason(Enum):
CONTEXT_LENGTH = "context_length"
CANCELLED = "cancelled"
FINISHED = "finished"
NEW_RESPONSE = "new_response"
class StreamStopInfo(BaseModel):
@@ -173,6 +176,7 @@ AnswerQuestionPossibleReturn = (
| ImageGenerationDisplay
| CustomToolResponse
| StreamingError
| GraphGenerationDisplay
| StreamStopInfo
)

View File

@@ -18,6 +18,8 @@ from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.chat_configs import BING_API_KEY
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
@@ -72,11 +74,16 @@ from danswer.search.utils import relevant_sections_to_indices
from danswer.server.query_and_chat.models import ChatMessageDetail
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.server.utils import get_json_line
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
from danswer.tools.built_in_tools import get_built_in_tool_by_id
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
from danswer.tools.custom.custom_tool import CustomToolCallSummary
from danswer.tools.force import ForceUseTool
from danswer.tools.graphing.graphing_tool import GraphingResponse
from danswer.tools.graphing.graphing_tool import GraphingTool
from danswer.tools.graphing.models import GraphGenerationDisplay
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
@@ -243,6 +250,7 @@ def _get_force_search_settings(
ChatPacket = (
StreamingError
| QADocsResponse
| GraphingResponse
| LLMRelevanceFilterResponse
| FinalUsedContextDocsResponse
| ChatMessageDetail
@@ -251,6 +259,7 @@ ChatPacket = (
| CitationInfo
| ImageGenerationDisplay
| CustomToolResponse
| GraphGenerationDisplay
| MessageSpecificCitations
| MessageResponseIDInfo
)
@@ -528,8 +537,21 @@ def stream_chat_message_objects(
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
for db_tool_model in persona.tools:
# handle in-code tools specially
if db_tool_model.in_code_tool_id:
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
if (
tool_cls.__name__ == CSVAnalysisTool.__name__
and not latest_query_files
):
tool_dict[db_tool_model.id] = [CSVAnalysisTool()]
if (
tool_cls.__name__ == GraphingTool.__name__
and not latest_query_files
):
tool_dict[db_tool_model.id] = [GraphingTool(output_dir="output")]
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
search_tool = SearchTool(
db_session=db_session,
@@ -600,7 +622,6 @@ def stream_chat_message_objects(
]
continue
# handle all custom tools
if db_tool_model.openapi_schema:
tool_dict[db_tool_model.id] = cast(
@@ -617,6 +638,11 @@ def stream_chat_message_objects(
tools: list[Tool] = []
for tool_list in tool_dict.values():
tools.extend(tool_list)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
# factor in tool definition size when pruning
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
@@ -662,86 +688,180 @@ def stream_chat_message_objects(
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
dropped_indices = None
tool_result = None
yielded_message_id_info = True
for packet in answer.processed_streamed_output:
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if isinstance(packet, StreamStopInfo):
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
break
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
db_citations = None
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
if reference_db_search_docs:
db_citations = _translate_citations(
citations_list=answer.citations,
db_docs=reference_db_search_docs,
)
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(CustomToolCallSummary, packet.response)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
# Saving Gen AI answer and responding with message info
if tool_result is None:
tool_call = None
else:
tool_call = ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments={
k: v if not isinstance(v, bytes) else v.decode("utf-8")
for k, v in tool_result.tool_args.items()
},
tool_result=tool_result.tool_result,
)
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
message=answer.llm_answer,
rephrased_query=cast(
QADocsResponse, qa_docs_response
).rephrased_query
if qa_docs_response is not None
else None,
reference_docs=reference_db_search_docs,
files=ai_message_files,
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
citations=cast(MessageSpecificCitations, db_citations).citation_map
if db_citations is not None
else None,
error=None,
tool_call=tool_call,
)
db_session.commit() # actually save user / assistant message
msg_detail_response = translate_db_message_to_chat_message_detail(
gen_ai_response_message
)
yield msg_detail_response
reserved_message_id = reserve_message_id(
db_session=db_session,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message.id
if user_message is not None
else gen_ai_response_message.id,
message_type=MessageType.ASSISTANT,
)
yielded_message_id_info = False
partial_response = partial(
create_new_chat_message,
chat_session_id=chat_session_id,
parent_message=gen_ai_response_message,
prompt_id=prompt_id,
overridden_model=overridden_model,
message_type=MessageType.ASSISTANT,
alternate_assistant_id=new_msg_req.alternate_assistant_id,
db_session=db_session,
commit=False,
)
reference_db_search_docs = None
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
if not yielded_message_id_info:
yield MessageResponseIDInfo(
user_message_id=gen_ai_response_message.id,
reserved_assistant_message_id=reserved_message_id,
)
yielded_message_id_info = True
if isinstance(packet, ToolResponse):
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
(
qa_docs_response,
reference_db_search_docs,
dropped_indices,
) = _handle_search_tool_response_summary(
packet=packet,
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
dedupe_docs=retrieval_options.dedupe_docs
if retrieval_options
else False,
)
yield qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if reference_db_search_docs is not None:
llm_indices = relevant_sections_to_indices(
relevance_sections=relevance_sections,
items=[
translate_db_search_doc_to_server_search_doc(doc)
for doc in reference_db_search_docs
],
)
if dropped_indices:
llm_indices = drop_llm_indices(
llm_indices=llm_indices,
search_docs=reference_db_search_docs,
dropped_indices=dropped_indices,
)
yield LLMRelevanceFilterResponse(
llm_selected_doc_indices=llm_indices
)
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
yield FinalUsedContextDocsResponse(
final_context_docs=packet.response
)
elif packet.id == GRAPHING_RESPONSE_ID:
graph_generation = cast(GraphingResponse, packet.response)
yield graph_generation
# yield GraphGenerationDisplay(
# file_id=graph_generation.extra_graph_display.file_id,
# line_graph=graph_generation.extra_graph_display.line_graph,
# )
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], packet.response
)
file_ids = save_files_from_urls(
[img.url for img in img_generation_response]
)
ai_message_files = [
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
for file_id in file_ids
]
yield ImageGenerationDisplay(
file_ids=[str(file_id) for file_id in file_ids]
)
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
(
qa_docs_response,
reference_db_search_docs,
) = _handle_internet_search_tool_response_summary(
packet=packet,
db_session=db_session,
)
yield qa_docs_response
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
custom_tool_response = cast(
CustomToolCallSummary, packet.response
)
yield CustomToolResponse(
response=custom_tool_response.tool_result,
tool_name=custom_tool_response.tool_name,
)
else:
if isinstance(packet, ToolCallFinalResult):
tool_result = packet
yield cast(ChatPacket, packet)
logger.debug("Reached end of stream")
except Exception as e:
error_msg = str(e)
@@ -767,11 +887,8 @@ def stream_chat_message_objects(
)
yield AllCitations(citations=answer.citations)
# Saving Gen AI answer and responding with message info
tool_name_to_tool_id: dict[str, int] = {}
for tool_id, tool_list in tool_dict.items():
for tool in tool_list:
tool_name_to_tool_id[tool.name] = tool_id
if answer.llm_answer == "":
return
gen_ai_response_message = partial_response(
reserved_message_id=reserved_message_id,
@@ -786,16 +903,14 @@ def stream_chat_message_objects(
if message_specific_citations
else None,
error=None,
tool_calls=[
ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
]
tool_call=ToolCall(
tool_id=tool_name_to_tool_id[tool_result.tool_name],
tool_name=tool_result.tool_name,
tool_arguments=tool_result.tool_args,
tool_result=tool_result.tool_result,
)
if tool_result
else [],
else None,
)
logger.debug("Committing messages")

View File

@@ -135,7 +135,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
# defaults to False

View File

@@ -165,6 +165,7 @@ class FileOrigin(str, Enum):
CONNECTOR = "connector"
GENERATED_REPORT = "generated_report"
OTHER = "other"
GRAPH_GEN = "graph_gen"
class PostgresAdvisoryLocks(Enum):

View File

@@ -178,8 +178,14 @@ def delete_search_doc_message_relationship(
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
db_session.execute(stmt)
chat_message = (
db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
)
if chat_message and chat_message.tool_call_id:
stmt = delete(ToolCall).where(ToolCall.id == chat_message.tool_call_id)
db_session.execute(stmt)
chat_message.tool_call_id = None
db_session.commit()
@@ -388,7 +394,7 @@ def get_chat_messages_by_session(
)
if prefetch_tool_calls:
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
stmt = stmt.options(joinedload(ChatMessage.tool_call))
result = db_session.scalars(stmt).unique().all()
else:
result = db_session.scalars(stmt).all()
@@ -474,7 +480,7 @@ def create_new_chat_message(
alternate_assistant_id: int | None = None,
# Maps the citation number [n] to the DB SearchDoc
citations: dict[int, int] | None = None,
tool_calls: list[ToolCall] | None = None,
tool_call: ToolCall | None = None,
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
@@ -494,7 +500,7 @@ def create_new_chat_message(
existing_message.message_type = message_type
existing_message.citations = citations
existing_message.files = files
existing_message.tool_calls = tool_calls if tool_calls else []
existing_message.tool_call = tool_call if tool_call else None
existing_message.error = error
existing_message.alternate_assistant_id = alternate_assistant_id
existing_message.overridden_model = overridden_model
@@ -513,7 +519,7 @@ def create_new_chat_message(
message_type=message_type,
citations=citations,
files=files,
tool_calls=tool_calls if tool_calls else [],
tool_call=tool_call if tool_call else None,
error=error,
alternate_assistant_id=alternate_assistant_id,
overridden_model=overridden_model,
@@ -747,14 +753,13 @@ def translate_db_message_to_chat_message_detail(
time_sent=chat_message.time_sent,
citations=chat_message.citations,
files=chat_message.files or [],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
alternate_assistant_id=chat_message.alternate_assistant_id,
overridden_model=chat_message.overridden_model,
)

View File

@@ -854,10 +854,8 @@ class ToolCall(Base):
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
message: Mapped["ChatMessage"] = relationship(
"ChatMessage", back_populates="tool_calls"
"ChatMessage", back_populates="tool_call"
)
@@ -984,9 +982,14 @@ class ChatMessage(Base):
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_calls: Mapped[list["ToolCall"]] = relationship(
"ToolCall",
back_populates="message",
tool_call_id: Mapped[int | None] = mapped_column(
ForeignKey("tool_call.id"), nullable=True
)
# NOTE: Should always be attached to the `assistant` message.
# represents the tool calls used to generate this message
tool_call: Mapped["ToolCall"] = relationship(
"ToolCall", back_populates="message", foreign_keys=[tool_call_id]
)
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
"StandardAnswer",

View File

@@ -13,6 +13,8 @@ class ChatFileType(str, Enum):
DOC = "document"
# Plain text only contain the text
PLAIN_TEXT = "plain_text"
# csv types contain the binary data
CSV = "csv"
class FileDescriptor(TypedDict):

View File

@@ -1,3 +1,4 @@
import base64
from collections.abc import Callable
from io import BytesIO
from typing import Any
@@ -16,6 +17,27 @@ from danswer.file_store.models import InMemoryChatFile
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
def save_base64_image(base64_image: str) -> str:
    """Persist a base64-encoded PNG to the default file store.

    Accepts either a raw base64 payload or a full ``data:image/...``
    data-URI (everything up to the first comma is stripped). The image
    is stored under a freshly generated UUID, which is returned so the
    caller can reference the saved file.
    """
    with get_session_context_manager() as db_session:
        payload = base64_image
        # Data-URIs carry a "data:image/...;base64," prefix — drop it.
        if payload.startswith("data:image"):
            payload = payload.split(",", 1)[1]

        decoded_bytes = base64.b64decode(payload)
        new_file_id = str(uuid4())

        get_default_file_store(db_session).save_file(
            file_name=new_file_id,
            content=BytesIO(decoded_bytes),
            display_name="GeneratedImage",
            file_origin=FileOrigin.CHAT_IMAGE_GEN,
            file_type="image/png",
        )
        return new_file_id
def load_chat_file(
file_descriptor: FileDescriptor, db_session: Session
) -> InMemoryChatFile:

View File

@@ -16,6 +16,7 @@ from danswer.chat.models import LlmDoc
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
from danswer.configs.constants import MessageType
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
@@ -40,11 +41,13 @@ from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
from danswer.tools.custom.custom_tool_prompt_builder import (
build_user_message_for_custom_tool_for_non_tool_calling_llm,
)
from danswer.tools.force import filter_tools_for_force_tool_use
from danswer.tools.force import ForceUseTool
from danswer.tools.graphing.graphing_tool import GraphingTool
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
from danswer.tools.images.image_generation_tool import ImageGenerationTool
@@ -68,6 +71,7 @@ from danswer.tools.tool_runner import ToolRunner
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
from danswer.tools.utils import explicit_tool_calling_supported
from danswer.utils.logger import setup_logger
from shared_configs.configs import MAX_TOOL_CALLS
logger = setup_logger()
@@ -161,6 +165,10 @@ class Answer:
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
self._is_cancelled = False
self.final_context_docs: list = []
self.current_streamed_output: list = []
self.processing_stream: list = []
def _update_prompt_builder_for_search_tool(
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
) -> None:
@@ -196,41 +204,50 @@ class Answer:
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_calls = 0
initiated = False
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
# / need to generate the args
tool_call_chunk = AIMessageChunk(
content="",
)
tool_call_chunk.tool_calls = [
{
"name": self.force_use_tool.tool_name,
"args": self.force_use_tool.args,
"id": str(uuid4()),
}
]
else:
# if tool calling is supported, first try the raw message
# to see if we don't need to use any tools
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
final_tool_definitions = [
tool.tool_definition()
for tool in filter_tools_for_force_tool_use(
self.tools, self.force_use_tool
)
]
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
if initiated:
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
initiated = True
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
tool_call_chunk: AIMessageChunk | None = None
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
tool_call_chunk = AIMessageChunk(content="")
tool_call_chunk.tool_calls = [
{
"name": self.force_use_tool.tool_name,
"args": self.force_use_tool.args,
"id": str(uuid4()),
}
]
else:
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question,
self.prompt_config,
self.latest_query_files,
tool_calls,
)
)
prompt = prompt_builder.build()
final_tool_definitions = [
tool.tool_definition()
for tool in filter_tools_for_force_tool_use(
self.tools, self.force_use_tool
)
]
print(final_tool_definitions)
for message in self.llm.stream(
prompt=prompt,
tools=final_tool_definitions if final_tool_definitions else None,
@@ -257,67 +274,129 @@ class Answer:
)
if not tool_call_chunk:
return # no tool call needed
logger.info("Skipped tool call but generated message")
return
# if we have a tool call, we need to call the tool
tool_call_requests = tool_call_chunk.tool_calls
for tool_call_request in tool_call_requests:
known_tools_by_name = [
tool for tool in self.tools if tool.name == tool_call_request["name"]
]
tool_call_requests = tool_call_chunk.tool_calls
print(tool_call_requests)
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
yield from tool_runner.tool_responses()
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
tool_call_result=build_tool_message(
tool_call_request, tool_runner.tool_message_content()
),
)
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
for tool_call_request in tool_call_requests:
tool_calls += 1
known_tools_by_name = [
tool
for tool in self.tools
if tool.name == tool_call_request["name"]
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question, img_urls=img_urls
if not known_tools_by_name:
logger.error(
"Tool call requested with unknown name field. \n"
f"self.tools: {self.tools}"
f"tool_call_request: {tool_call_request}"
)
if self.tools:
tool = self.tools[0]
else:
continue
else:
tool = known_tools_by_name[0]
tool_args = (
self.force_use_tool.args
if self.force_use_tool.tool_name == tool.name
and self.force_use_tool.args
else tool_call_request["args"]
)
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
yield tool_kickoff
yield from tool_runner.tool_responses()
tool_responses = []
for response in tool_runner.tool_responses():
tool_responses.append(response)
yield response
tool_call_summary = ToolCallSummary(
tool_call_request=tool_call_chunk,
tool_call_result=build_tool_message(
tool_call_request, tool_runner.tool_message_content()
),
)
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
self._update_prompt_builder_for_search_tool(prompt_builder, [])
elif tool.name == ImageGenerationTool._NAME:
img_urls = [
img_generation_result["url"]
for img_generation_result in tool_runner.tool_final_result().tool_result
]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question, img_urls=img_urls
)
)
yield tool_runner.tool_final_result()
# Update message history with tool call and response
self.message_history.append(
PreviousMessage(
message=self.question,
message_type=MessageType.USER,
token_count=len(self.llm_tokenizer.encode(self.question)),
tool_call=None,
files=[],
)
)
yield tool_runner.tool_final_result()
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
self.message_history.append(
PreviousMessage(
message=str(tool_call_request),
message_type=MessageType.ASSISTANT,
token_count=len(
self.llm_tokenizer.encode(str(tool_call_request))
),
tool_call=None,
files=[],
)
)
self.message_history.append(
PreviousMessage(
message="\n".join(str(response) for response in tool_responses),
message_type=MessageType.SYSTEM,
token_count=len(
self.llm_tokenizer.encode(
"\n".join(str(response) for response in tool_responses)
)
),
tool_call=None,
files=[],
)
)
yield from self._process_llm_stream(
prompt=prompt,
tools=[tool.tool_definition() for tool in self.tools],
)
# Generate response based on updated message history
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
return
response_content = ""
for content in self._process_llm_stream(prompt=prompt, tools=None):
if isinstance(content, str):
response_content += content
yield content
# Update message history with LLM response
self.message_history.append(
PreviousMessage(
message=response_content,
message_type=MessageType.ASSISTANT,
token_count=len(self.llm_tokenizer.encode(response_content)),
tool_call=None,
files=[],
)
)
# This method processes the LLM stream and yields the content or stop information
def _process_llm_stream(
@@ -346,139 +425,234 @@ class Answer:
) -> Iterator[
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
]:
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
tool_calls = 0
initiated = False
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
if initiated:
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
iter(
[
tool
for tool in self.tools
if tool.name == self.force_use_tool.tool_name
]
),
None,
)
if not tool:
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
initiated = True
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
chosen_tool_and_args: tuple[Tool, dict] | None = None
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
if self.force_use_tool.force_use:
# if we are forcing a tool, we don't need to check which tools to run
tool = next(
iter(
[
tool
for tool in self.tools
if tool.name == self.force_use_tool.tool_name
]
),
None,
)
if not tool:
raise RuntimeError(
f"Tool '{self.force_use_tool.tool_name}' not found"
)
tool_args = (
self.force_use_tool.args
if self.force_use_tool.args is not None
else tool.get_args_for_non_tool_calling_llm(
query=self.question,
history=self.message_history,
llm=self.llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
chosen_tool_and_args = (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
query=self.question,
history=self.message_history,
llm=self.llm,
force_run=True,
)
)
if tool_args is None:
raise RuntimeError(f"Tool '{tool.name}' did not return args")
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
chosen_tool_and_args = (tool, tool_args)
else:
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
tools=self.tools,
query=self.question,
history=self.message_history,
llm=self.llm,
)
available_tools_and_args = [
(self.tools[ind], args)
for ind, args in enumerate(tool_options)
if args is not None
]
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
logger.info(
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
chosen_tool_and_args = (
select_single_tool_for_non_tool_calling_llm(
tools_and_args=available_tools_and_args,
history=self.message_history,
query=self.question,
llm=self.llm,
)
if available_tools_and_args
else None
)
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
if not chosen_tool_and_args:
prompt_builder.update_system_prompt(
default_build_system_message(self.prompt_config)
)
prompt_builder.update_user_prompt(
default_build_user_message(
self.question, self.prompt_config, self.latest_query_files
)
)
prompt = prompt_builder.build()
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
return
tool_calls += 1
tool, tool_args = chosen_tool_and_args
print("tool args")
print(tool_args)
tool_runner = ToolRunner(tool, tool_args, self.llm)
yield tool_runner.kickoff()
tool_responses = []
file_name = tool_runner.args["filename"]
print(f"file ame is {file_name}")
csv_file = None
for message in self.message_history:
if message.files:
csv_file = next(
(file for file in message.files if file.filename == file_name),
None,
)
if csv_file:
break
print(self.latest_query_files)
if csv_file is None:
raise ValueError(
f"CSV file with name '{file_name}' not found in latest query files."
)
print("csv file found")
tool_runner.args["filename"] = csv_file.content
for response in tool_runner.tool_responses():
tool_responses.append(response)
yield response
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_documents = cast(list[LlmDoc], response.response)
yield response
if final_context_documents is None:
raise RuntimeError(
f"{tool.name} did not return final context documents"
)
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name == ImageGenerationTool._NAME:
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], response.response
)
# img_urls = [img.url for img in img_generation_response]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
img_urls=[img.url for img in img_generation_response],
)
)
yield response
elif tool.name == CSVAnalysisTool._NAME:
for response in tool_runner.tool_responses():
yield response
elif tool.name == GraphingTool._NAME:
for response in tool_runner.tool_responses():
print("RESOS")
print(response)
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
# img_urls=img_urls,
)
)
else:
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name,
*tool_runner.tool_responses(),
)
)
)
final_result = tool_runner.tool_final_result()
yield final_result
# Update message history
self.message_history.extend(
[
PreviousMessage(
message=str(self.question),
message_type=MessageType.USER,
token_count=len(self.llm_tokenizer.encode(str(self.question))),
tool_call=None,
files=[],
),
PreviousMessage(
message=f"Tool used: {tool.name}",
message_type=MessageType.ASSISTANT,
token_count=len(
self.llm_tokenizer.encode(f"Tool used: {tool.name}")
),
tool_call=None,
files=[],
),
PreviousMessage(
message=str(final_result),
message_type=MessageType.SYSTEM,
token_count=len(self.llm_tokenizer.encode(str(final_result))),
tool_call=None,
files=[],
),
]
)
# Generate response based on updated message history
prompt = prompt_builder.build()
yield from self._process_llm_stream(
prompt=prompt,
tools=None,
)
return
tool, tool_args = chosen_tool_and_args
tool_runner = ToolRunner(tool, tool_args)
yield tool_runner.kickoff()
response_content = ""
for content in self._process_llm_stream(prompt=prompt, tools=None):
if isinstance(content, str):
response_content += content
yield content
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
final_context_documents = None
for response in tool_runner.tool_responses():
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_context_documents = cast(list[LlmDoc], response.response)
yield response
if final_context_documents is None:
raise RuntimeError(
f"{tool.name} did not return final context documents"
)
self._update_prompt_builder_for_search_tool(
prompt_builder, final_context_documents
)
elif tool.name == ImageGenerationTool._NAME:
img_urls = []
for response in tool_runner.tool_responses():
if response.id == IMAGE_GENERATION_RESPONSE_ID:
img_generation_response = cast(
list[ImageGenerationResponse], response.response
)
img_urls = [img.url for img in img_generation_response]
yield response
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=self.question,
img_urls=img_urls,
# Update message history with LLM response
self.message_history.append(
PreviousMessage(
message=response_content,
message_type=MessageType.ASSISTANT,
token_count=len(self.llm_tokenizer.encode(response_content)),
tool_call=None,
files=[],
)
)
else:
prompt_builder.update_user_prompt(
HumanMessage(
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
self.question,
tool.name,
*tool_runner.tool_responses(),
)
)
)
final = tool_runner.tool_final_result()
yield final
prompt = prompt_builder.build()
yield from self._process_llm_stream(prompt=prompt, tools=None)
@property
def processed_streamed_output(self) -> AnswerStream:
@@ -495,6 +669,8 @@ class Answer:
else self._raw_output_for_non_explicit_tool_calling_llms()
)
self.processing_stream = []
def _process_stream(
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
) -> AnswerStream:
@@ -535,56 +711,70 @@ class Answer:
yield message
else:
# assumes all tool responses will come first, then the final answer
break
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
if not self.skip_gen_ai_answer_generation:
process_answer_stream_fn = _get_answer_stream_processor(
context_docs=final_context_docs or [],
# if doc selection is enabled, then search_results will be None,
# so we need to use the final_context_docs
doc_id_to_rank_map=map_document_id_order(
search_results or final_context_docs or []
),
answer_style_configs=self.answer_style_config,
)
stream_stop_info = None
new_kickoff = None
stream_stop_info = None
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
nonlocal new_kickoff
def _stream() -> Iterator[str]:
nonlocal stream_stop_info
yield cast(str, message)
for item in stream:
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
yield cast(str, item)
yield cast(str, message)
for item in stream:
if isinstance(item, StreamStopInfo):
stream_stop_info = item
return
if isinstance(item, ToolCallKickoff):
new_kickoff = item
return
else:
yield cast(str, item)
yield from process_answer_stream_fn(_stream())
yield from process_answer_stream_fn(_stream())
if stream_stop_info:
yield stream_stop_info
if stream_stop_info:
yield stream_stop_info
# handle new tool call (continuation of message)
if new_kickoff:
yield new_kickoff
processed_stream = []
for processed_packet in _process_stream(output_generator):
processed_stream.append(processed_packet)
yield processed_packet
if (
isinstance(processed_packet, StreamStopInfo)
and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
):
self.current_streamed_output = self.processing_stream
self.processing_stream = []
self._processed_stream = processed_stream
self.processing_stream.append(processed_packet)
yield processed_packet
self.current_streamed_output = self.processing_stream
self._processed_stream = self.processing_stream
@property
def llm_answer(self) -> str:
answer = ""
for packet in self.processed_streamed_output:
if not self._processed_stream and not self.current_streamed_output:
return ""
for packet in self.current_streamed_output or self._processed_stream or []:
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
return answer
@property
def citations(self) -> list[CitationInfo]:
citations: list[CitationInfo] = []
for packet in self.processed_streamed_output:
for packet in self.current_streamed_output:
if isinstance(packet, CitationInfo):
citations.append(packet)
@@ -599,5 +789,4 @@ class Answer:
if not self.is_connected():
logger.debug("Answer stream has been cancelled")
self._is_cancelled = not self.is_connected()
return self._is_cancelled

View File

@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
token_count: int
message_type: MessageType
files: list[InMemoryChatFile]
tool_calls: list[ToolCallFinalResult]
tool_call: ToolCallFinalResult | None
@classmethod
def from_chat_message(
@@ -51,14 +51,13 @@ class PreviousMessage(BaseModel):
for file in available_files
if str(file.file_id) in message_file_ids
],
tool_calls=[
ToolCallFinalResult(
tool_name=tool_call.tool_name,
tool_args=tool_call.tool_arguments,
tool_result=tool_call.tool_result,
)
for tool_call in chat_message.tool_calls
],
tool_call=ToolCallFinalResult(
tool_name=chat_message.tool_call.tool_name,
tool_args=chat_message.tool_call.tool_arguments,
tool_result=chat_message.tool_call.tool_result,
)
if chat_message.tool_call
else None,
)
def to_langchain_msg(self) -> BaseMessage:

View File

@@ -36,7 +36,10 @@ def default_build_system_message(
def default_build_user_message(
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
user_query: str,
prompt_config: PromptConfig,
files: list[InMemoryChatFile] = [],
previous_tool_calls: int = 0,
) -> HumanMessage:
user_prompt = (
CHAT_USER_CONTEXT_FREE_PROMPT.format(
@@ -45,6 +48,12 @@ def default_build_user_message(
if prompt_config.task_prompt
else user_query
)
if previous_tool_calls > 0:
user_prompt = (
f"You have already generated the above so do not call a tool if not necessary. "
f"Remember the query is: `{user_prompt}`"
)
user_prompt = user_prompt.strip()
user_msg = HumanMessage(
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
@@ -113,25 +122,25 @@ class AnswerPromptBuilder:
final_messages_with_tokens.append(self.user_message_and_token_cnt)
if tool_call_summary:
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_request,
check_message_tokens(
tool_call_summary.tool_call_request,
self.llm_tokenizer_encode_func,
),
)
)
final_messages_with_tokens.append(
(
tool_call_summary.tool_call_result,
check_message_tokens(
tool_call_summary.tool_call_result,
self.llm_tokenizer_encode_func,
),
)
)
# if tool_call_summary:
# final_messages_with_tokens.append(
# (
# tool_call_summary.tool_call_request,
# check_message_tokens(
# tool_call_summary.tool_call_request,
# self.llm_tokenizer_encode_func,
# ),
# )
# )
# final_messages_with_tokens.append(
# (
# tool_call_summary.tool_call_result,
# check_message_tokens(
# tool_call_summary.tool_call_result,
# self.llm_tokenizer_encode_func,
# ),
# )
# )
return drop_messages_history_overflow(
final_messages_with_tokens, self.max_tokens

View File

@@ -29,6 +29,7 @@ from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.llm.interfaces import ToolChoiceOptions
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
from danswer.utils.logger import setup_logger
@@ -98,6 +99,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
@@ -124,12 +126,21 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
"name": message.name,
}
elif isinstance(message, ToolMessage):
message_dict = {
"tool_call_id": message.tool_call_id,
"role": "tool",
"name": message.name or "",
"content": message.content,
}
if message.id == GRAPHING_RESPONSE_ID:
message_dict = {
"tool_call_id": message.tool_call_id,
"role": "tool",
"name": message.name or "",
"content": "a graph",
}
else:
message_dict = {
"tool_call_id": message.tool_call_id,
"role": "tool",
"name": message.name or "",
"content": "a graph",
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:

View File

@@ -112,7 +112,7 @@ def translate_danswer_msg_to_langchain(
content = build_content_with_imgs(msg.message, files)
if msg.message_type == MessageType.SYSTEM:
raise ValueError("System messages are not currently part of history")
return SystemMessage(content=content)
if msg.message_type == MessageType.ASSISTANT:
return AIMessage(content=content)
if msg.message_type == MessageType.USER:
@@ -133,6 +133,21 @@ def translate_history_to_basemessages(
return history_basemessages, history_token_counts
def _process_csv_file(file: InMemoryChatFile) -> str:
    """Build a prompt-ready preview section for an uploaded CSV chat file.

    Decodes the file bytes as UTF-8, renders the first rows via pandas, and
    prefixes a header naming the file (or noting that no name was provided).
    """
    import io

    import pandas as pd

    frame = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
    preview_text = frame.head().to_string()

    if file.filename:
        header = f"CSV FILE NAME: {file.filename}\n"
    else:
        header = "CSV FILE (NO NAME PROVIDED):\n"

    return f"{header}{CODE_BLOCK_PAT.format(preview_text)}\n\n\n"
def _build_content(
message: str,
files: list[InMemoryChatFile] | None = None,
@@ -143,16 +158,26 @@ def _build_content(
if files
else None
)
if not text_files:
csv_files = (
[file for file in files if file.file_type == ChatFileType.CSV]
if files
else None
)
if not text_files and not csv_files:
return message
final_message_with_files = "FILES:\n\n"
for file in text_files:
for file in text_files or []:
file_content = file.content.decode("utf-8")
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
final_message_with_files += (
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
)
for file in csv_files or []:
final_message_with_files += _process_csv_file(file)
final_message_with_files += message
return final_message_with_files

View File

@@ -39,6 +39,46 @@ CHAT_USER_CONTEXT_FREE_PROMPT = f"""
""".strip()
GRAPHING_QUERY_REPHRASE_GRAPH = f"""
Given the following conversation and a follow-up input, generate Python code using matplotlib to create the requested graph.
IMPORTANT: The code should be complete and executable, including data generation if necessary.
Focus on creating a clear and informative visualization based on the user's request.
Guidelines:
- Import matplotlib.pyplot as plt at the beginning of the code.
- Generate sample data if specific data is not provided in the conversation.
- Use appropriate graph types (line, bar, scatter, pie) based on the nature of the data and request.
- Include proper labeling (title, x-axis, y-axis, legend) in the graph.
- Use plt.figure() to create the figure and assign it to a variable named 'fig'.
- Do not include plt.show() at the end of the code.
{GENERAL_SEP_PAT}
Chat History:
{{chat_history}}
{GENERAL_SEP_PAT}
Follow Up Input: {{question}}
Python Code for Graph (Respond with only the Python code to generate the graph):
```python
# Your code here
```
""".strip()
GRAPHING_GET_FILE_NAME_PROMPT = f"""
Given the following conversation, a follow-up input, and a list of available CSV files,
provide the name of the CSV file to analyze.
{GENERAL_SEP_PAT}
Chat History:
{{chat_history}}
{GENERAL_SEP_PAT}
Follow Up Input: {{question}}
{GENERAL_SEP_PAT}
Available CSV Files:
{{file_list}}
{GENERAL_SEP_PAT}
CSV File Name to Analyze:
"""
# Design considerations for the below:
# - In case of uncertainty, favor yes search so place the "yes" sections near the start of the
# prompt and after the no section as well to deemphasize the no section

View File

@@ -122,6 +122,25 @@ def upload_file(
return {"file_id": file_id}
@admin_router.post("/upload-csv")
def upload_csv(
    file: UploadFile,
    db_session: Session = Depends(get_session),
    _: User | None = Depends(current_user),
) -> dict[str, str]:
    """Persist an uploaded CSV into the default file store.

    Returns the generated file id so the client can reference the upload
    in subsequent chat requests.
    """
    store = get_default_file_store(db_session)
    new_file_id = str(uuid.uuid4())
    store.save_file(
        file_name=new_file_id,
        content=file.file,
        display_name=file.filename,
        file_origin=FileOrigin.CHAT_UPLOAD,
        # fall back to the CSV chat-file type when the client sent no MIME type
        file_type=file.content_type or ChatFileType.CSV.value,
    )
    return {"file_id": new_file_id}
"""Endpoints for all"""

View File

@@ -281,14 +281,17 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
def is_disconnected_sync() -> bool:
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
try:
return not future.result(timeout=0.01)
result = not future.result(timeout=0.01)
return result
except asyncio.TimeoutError:
logger.error("Asyncio timed out")
logger.error("Asyncio timed out while checking client connection")
return True
except asyncio.CancelledError:
return True
except Exception as e:
error_msg = str(e)
logger.critical(
f"An unexpected error occured with the disconnect check coroutine: {error_msg}"
f"An unexpected error occurred with the disconnect check coroutine: {error_msg}"
)
return True
@@ -517,7 +520,6 @@ def upload_files_for_chat(
image_content_types = {"image/jpeg", "image/png", "image/webp"}
text_content_types = {
"text/plain",
"text/csv",
"text/markdown",
"text/x-markdown",
"text/x-config",
@@ -527,6 +529,9 @@ def upload_files_for_chat(
"text/xml",
"application/x-yaml",
}
csv_content_types = {
"text/csv",
}
document_content_types = {
"application/pdf",
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
@@ -536,8 +541,10 @@ def upload_files_for_chat(
"application/epub+zip",
}
allowed_content_types = image_content_types.union(text_content_types).union(
document_content_types
allowed_content_types = (
image_content_types.union(text_content_types)
.union(document_content_types)
.union(csv_content_types)
)
for file in files:
@@ -545,8 +552,12 @@ def upload_files_for_chat(
if file.content_type in image_content_types:
error_detail = "Unsupported image file type. Supported image types include .jpg, .jpeg, .png, .webp."
elif file.content_type in text_content_types:
error_detail = "Unsupported text file type. Supported text types include .txt, .csv, .md, .mdx, .conf, "
error_detail = "Unsupported text file type. Supported text types include .txt, .md, .mdx, .conf, "
".log, .tsv."
elif file.content_type in csv_content_types:
error_detail = (
"Unsupported csv file type. Supported CSV types include .csv "
)
else:
error_detail = (
"Unsupported document file type. Supported document types include .pdf, .docx, .pptx, .xlsx, "
@@ -572,6 +583,8 @@ def upload_files_for_chat(
file_type = ChatFileType.IMAGE
elif file.content_type in document_content_types:
file_type = ChatFileType.DOC
elif file.content_type in csv_content_types:
file_type = ChatFileType.CSV
else:
file_type = ChatFileType.PLAIN_TEXT
@@ -584,6 +597,7 @@ def upload_files_for_chat(
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=file.content_type or file_type.value,
)
print(f"FILE TYPE IS {file_type}")
# if the file is a doc, extract text and store that so we don't need
# to re-extract it every time we send a message

View File

@@ -178,7 +178,7 @@ class ChatMessageDetail(BaseModel):
chat_session_id: int | None = None
citations: dict[int, int] | None = None
files: list[FileDescriptor]
tool_calls: list[ToolCallFinalResult]
tool_call: ToolCallFinalResult | None
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore

View File

@@ -1,19 +1,22 @@
import base64
import json
from datetime import datetime
from typing import Any
class DateTimeEncoder(json.JSONEncoder):
"""Custom JSON encoder that converts datetime objects to ISO format strings."""
class DateTimeAndBytesEncoder(json.JSONEncoder):
"""Custom JSON encoder that converts datetime objects to ISO format strings and bytes to base64."""
def default(self, obj: Any) -> Any:
if isinstance(obj, datetime):
return obj.isoformat()
elif isinstance(obj, bytes):
return base64.b64encode(obj).decode("utf-8")
return super().default(obj)
def get_json_line(
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeAndBytesEncoder
) -> str:
"""
Convert a dictionary to a JSON string with datetime handling, and add a newline.

View File

@@ -0,0 +1,161 @@
import json
from collections.abc import Generator
from typing import Any
import pandas as pd
from danswer.dynamic_configs.interface import JSON_ro
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
logger = setup_logger()
CSV_ANALYSIS_RESPONSE_ID = "csv_analysis_response"
YES_ANALYSIS = "Yes Analysis"
SKIP_ANALYSIS = "Skip Analysis"
ANALYSIS_TEMPLATE = f"""
Given the conversation history and a follow up query,
determine if the system should analyze a CSV file to better answer the latest user input.
Your default response is {SKIP_ANALYSIS}.
Respond "{YES_ANALYSIS}" if:
- The user is asking about the structure or content of a CSV file.
- The user explicitly requests information about a data file.
Conversation History:
{{chat_history}}
If you are at all unsure, respond with {SKIP_ANALYSIS}.
Respond with EXACTLY and ONLY "{YES_ANALYSIS}" or "{SKIP_ANALYSIS}"
Follow Up Input:
{{final_query}}
""".strip()
system_message = """
You analyze CSV files by examining their structure and content.
Your analysis should include:
1. Number of columns and their names
2. Data types of each column
3. First few rows of data
4. Basic statistics (if applicable)
Provide a concise summary of the file's content and structure.
"""
class CSVAnalysisTool(Tool):
    """Built-in tool that previews a CSV file and reports its structure.

    The analysis covers column count/names, dtypes, the first few rows, and
    basic statistics, serialized as a JSON payload in the ToolResponse.
    """

    _NAME = "analyze_csv"
    _DISPLAY_NAME = "CSV Analysis Tool"
    _DESCRIPTION = system_message

    @property
    def name(self) -> str:
        return self._NAME

    @property
    def description(self) -> str:
        return self._DESCRIPTION

    @property
    def display_name(self) -> str:
        return self._DISPLAY_NAME

    def tool_definition(self) -> dict:
        """OpenAI-style function schema; requires a `file_path` string arg."""
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": {
                    "type": "object",
                    "properties": {
                        "file_path": {
                            "type": "string",
                            "description": "Path to the CSV file to analyze",
                        },
                    },
                    "required": ["file_path"],
                },
            },
        }

    def check_if_needs_analysis(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
    ) -> bool:
        """Ask the LLM whether CSV analysis would help answer `query`.

        Returns True when the LLM reply contains the YES_ANALYSIS marker
        (case-insensitive).
        """
        history_str = "\n".join([f"{m.message}" for m in history])
        prompt = ANALYSIS_TEMPLATE.format(
            chat_history=history_str,
            final_query=query,
        )
        use_analysis_output = llm.invoke(prompt)
        logger.debug(f"Evaluated if should use CSV analysis: {use_analysis_output}")
        content = use_analysis_output.content
        return YES_ANALYSIS.lower() in str(content).lower()

    def get_args_for_non_tool_calling_llm(
        self,
        query: str,
        history: list[PreviousMessage],
        llm: LLM,
        force_run: bool = False,
    ) -> dict[str, Any] | None:
        """Return tool args for non-tool-calling LLMs, or None to skip the tool."""
        if not force_run and not self.check_if_needs_analysis(query, history, llm):
            return None
        # NOTE(review): run() reads kwargs["file_path"], but this returns a
        # "prompt" key — the non-tool-calling path would KeyError inside run().
        # TODO confirm the intended argument name against the caller.
        return {
            "prompt": query,
        }

    def build_tool_message_content(
        self, *args: ToolResponse
    ) -> str | list[str | dict[str, Any]]:
        """Return the JSON analysis payload for inclusion in a tool message."""
        analysis_response = next(
            arg for arg in args if arg.id == CSV_ANALYSIS_RESPONSE_ID
        )
        # run() yields `response` as an already-JSON-encoded string, so pass it
        # through directly; calling `.dict()` on it (as before) raises
        # AttributeError on a str.
        return str(analysis_response.response)

    def run(
        self, llm: LLM | None = None, **kwargs: str
    ) -> Generator[ToolResponse, None, None]:
        """Read the first rows of the CSV at `file_path` and yield a JSON summary.

        On failure, yields a ToolResponse with id "ERROR" instead of raising.
        """
        if llm is not None:
            logger.warning("LLM passed to CSVAnalysisTool.run() but not used")
        file_path = kwargs["file_path"]
        try:
            # Only sample the file — five rows is enough for a structural preview
            df = pd.read_csv(file_path, nrows=5)
            analysis = {
                "num_columns": len(df.columns),
                "column_names": df.columns.tolist(),
                "data_types": df.dtypes.astype(str).tolist(),
                "first_rows": df.to_dict(orient="records"),
                "basic_stats": df.describe().to_dict(),
            }
            analysis_json = json.dumps(analysis, indent=2)
            yield ToolResponse(id=CSV_ANALYSIS_RESPONSE_ID, response=analysis_json)
        except Exception as e:
            error_msg = f"Error analyzing CSV file: {str(e)}"
            logger.error(error_msg)
            yield ToolResponse(id="ERROR", response=error_msg)

    def final_result(self, *args: ToolResponse) -> JSON_ro:
        """Decode the JSON analysis for persistence; returns an error dict instead of raising."""
        try:
            analysis_response = next(
                arg for arg in args if arg.id == CSV_ANALYSIS_RESPONSE_ID
            )
            return json.loads(analysis_response.response)
        except Exception as e:
            return {"error": f"Unexpected error in final_result: {str(e)}"}

View File

@@ -9,6 +9,8 @@ from sqlalchemy.orm import Session
from danswer.db.models import Persona
from danswer.db.models import Tool as ToolDBModel
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
from danswer.tools.graphing.graphing_tool import GraphingTool
from danswer.tools.images.image_generation_tool import ImageGenerationTool
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
from danswer.tools.search.search_tool import SearchTool
@@ -41,6 +43,18 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
in_code_tool_id=ImageGenerationTool.__name__,
display_name=ImageGenerationTool._DISPLAY_NAME,
),
InCodeToolInfo(
cls=GraphingTool,
description=("The graphing Tool allows the assistant to make graphs. "),
in_code_tool_id=GraphingTool.__name__,
display_name=GraphingTool._DISPLAY_NAME,
),
InCodeToolInfo(
cls=CSVAnalysisTool,
description=("The CSV Tool allows the assistant to make graphs. "),
in_code_tool_id=CSVAnalysisTool.__name__,
display_name=CSVAnalysisTool._DISPLAY_NAME,
),
# don't show the InternetSearchTool as an option if BING_API_KEY is not available
*(
[

View File

@@ -144,7 +144,9 @@ class CustomTool(Tool):
"""Actual execution of the tool"""
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
def run(
self, llm: None | LLM = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:
request_body = kwargs.get(REQUEST_BODY)
path_params = {}

View File

@@ -0,0 +1,377 @@
import json
import os
import re
import traceback
import uuid
from collections.abc import Generator
from io import BytesIO
from io import StringIO
from typing import Any
from typing import cast
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from danswer.db.engine import get_session_context_manager
from danswer.db.models import FileOrigin
from danswer.dynamic_configs.interface import JSON_ro
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.utils import _process_csv_file
from danswer.prompts.chat_prompts import (
GRAPHING_GET_FILE_NAME_PROMPT,
) # You'll need to create this
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
from danswer.tools.graphing.models import GraphingError
from danswer.tools.graphing.models import GraphingResponse
from danswer.tools.graphing.models import GraphType
from danswer.tools.tool import Tool
from danswer.tools.tool import ToolResponse
from danswer.utils.logger import setup_logger
matplotlib.use("Agg") # Use non-interactive backend
logger = setup_logger()
FINAL_GRAPH_IMAGE = "final_graph_image"
YES_GRAPHING = "Yes Graphing"
SKIP_GRAPHING = "Skip Graphing"
GRAPHING_TEMPLATE = f"""
Given the conversation history and a follow up query,
determine if the system should create a graph to better answer the latest user input.
Your default response is {SKIP_GRAPHING}.
Respond "{YES_GRAPHING}" if:
- The user is asking for information that would be better represented in a graph.
- The user explicitly requests a graph or chart.
Conversation History:
{{chat_history}}
If you are at all unsure, respond with {SKIP_GRAPHING}.
Respond with EXACTLY and ONLY "{YES_GRAPHING}" or "{SKIP_GRAPHING}"
Follow Up Input:
{{final_query}}
""".strip()
# System prompt instructing the LLM how to emit matplotlib graphing code.
# (A stray bare `TabError` expression statement that followed this constant
# was removed — it was dead leftover code with no effect.)
system_message = """
You create Python code for graphs using matplotlib. Your code should:
Import libraries: matplotlib.pyplot as plt, numpy as np
Define data (create sample data if needed)
Create the plot:
Use fig, ax = plt.subplots(figsize=(10, 6))
Use ax methods for plotting (e.g., ax.plot(), ax.bar())
Set labels, title, legend using ax methods
Not include plt.show()
Key points:
Use 'ax' for all plotting functions
Use 'plt' only for figure creation and any necessary global settings
Provide raw Python code without formatting
"""
class GraphingTool(Tool):
_NAME = "create_graph"
_DISPLAY_NAME = "Graphing Tool"
_DESCRIPTION = system_message
def __init__(self, output_dir: str = "generated_graphs"):
    # Directory where generated graph files will be written.
    self.output_dir = output_dir
    try:
        os.makedirs(output_dir, exist_ok=True)
    except Exception as e:
        # Best-effort: construction should not fail if the directory cannot
        # be created; the error is logged and surfaces later on first write.
        logger.error(f"Error creating output directory: {e}")
@property
def name(self) -> str:
    # Tool identifier ("create_graph").
    return self._NAME

@property
def description(self) -> str:
    # Prompt-facing description (the graphing system message above).
    return self._DESCRIPTION

@property
def display_name(self) -> str:
    # Human-readable name ("Graphing Tool").
    return self._DISPLAY_NAME
# def tool_definition(self) -> dict:
# return {
# "type": "function",
# "function": {
# "name": self.name,
# "description": self.description,
# "parameters": {
# "type": "object",
# "properties": {
# "code": {
# "type": "string",
# "description": "Python code to generate the graph using matplotlib and seaborn",
# },
# },
# "required": ["code"],
# },
# },
# }
def tool_definition(self) -> dict:
return {
"type": "function",
"function": {
"name": self.name,
"description": self.description,
"parameters": {
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "The name of the file to analyze for graphing purposes",
},
},
"required": ["filename"],
},
},
}
def check_if_needs_graphing(
self,
query: str,
history: list[PreviousMessage],
llm: LLM,
) -> bool:
# return True
history_str = "\n".join([f"{m.message}" for m in history])
prompt = GRAPHING_TEMPLATE.format(
chat_history=history_str,
final_query=query,
)
use_graphing_output = llm.invoke(prompt)
print(use_graphing_output)
logger.debug(f"Evaluated if should use graphing: {use_graphing_output}")
content = use_graphing_output.content
return YES_GRAPHING.lower() in str(content).lower()
def get_args_for_non_tool_calling_llm(
self,
query: str,
history: list[PreviousMessage],
llm: LLM,
force_run: bool = False,
) -> dict[str, Any] | None:
if not force_run and not self.check_if_needs_graphing(query, history, llm):
return None
file_names_and_descriptions = []
for message in history:
for file in message.files:
if file.file_type == ChatFileType.CSV:
file_name = file.filename
description = _process_csv_file(file)
file_names_and_descriptions.append(f"{file_name}: {description}")
# Use the GRAPHING_GET_FILE_NAME_PROMPT to get the file name
file_list = "\n".join(file_names_and_descriptions)
prompt = GRAPHING_GET_FILE_NAME_PROMPT.format(
chat_history="\n".join([f"{m.message}" for m in history]),
question=query,
file_list=file_list,
)
file_name_response = llm.invoke(prompt)
file_name = (
file_name_response.content
if isinstance(file_name_response.content, str)
else ""
)
file_name = file_name.strip()
# Validate that the returned file name is in our list of available files
available_files = [
name.split(":")[0].strip() for name in file_names_and_descriptions
]
if file_name not in available_files:
logger.warning(f"LLM returned invalid file name: {file_name}")
file_name = available_files[0] if available_files else None
return {
"filename": file_name,
}
def build_tool_message_content(
self, *args: ToolResponse
) -> str | list[str | dict[str, Any]]:
graph_response = next(arg for arg in args if arg.id == GRAPHING_RESPONSE_ID)
return json.dumps(graph_response.response.dict())
@staticmethod
def preprocess_code(code: str) -> str:
# Extract code between triple backticks
code_match = re.search(r"```python\n(.*?)```", code, re.DOTALL)
if code_match:
return code_match.group(1).strip()
# If no code block is found, remove any explanatory text and return the rest
return re.sub(r"^.*?import", "import", code, flags=re.DOTALL).strip()
@staticmethod
def is_line_plot(ax: plt.Axes) -> bool:
return len(ax.lines) > 0
@staticmethod
def is_bar_plot(ax: plt.Axes) -> bool:
return len(ax.patches) > 0 and isinstance(ax.patches[0], plt.Rectangle)
@staticmethod
def extract_line_plot_data(ax: plt.Axes) -> dict[str, Any]:
data = []
for line in ax.lines:
line_data = {
"x": line.get_xdata(), # type: ignore
"y": line.get_ydata(), # type: ignore
"label": line.get_label(),
"color": line.get_color(),
}
data.append(line_data)
return {
"data": data,
"title": ax.get_title(),
"xlabel": ax.get_xlabel(),
"ylabel": ax.get_ylabel(),
}
@staticmethod
def extract_bar_plot_data(ax: plt.Axes) -> dict[str, Any]:
data = []
for patch in ax.patches:
bar_data = {
"x": float(patch.get_bbox().x0 + patch.get_bbox().width / 2), # type: ignore
"y": float(patch.get_bbox().height), # type: ignore
"width": float(patch.get_bbox().width), # type: ignore
"color": patch.get_facecolor(),
}
data.append(bar_data)
return {
"data": data,
"title": ax.get_title(),
"xlabel": ax.get_xlabel(),
"ylabel": ax.get_ylabel(),
"xticks": ax.get_xticks(),
"xticklabels": [label.get_text() for label in ax.get_xticklabels()],
}
def run(
self, llm: LLM | None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
if llm is None:
raise ValueError("This tool requires an LLM to run")
file_content = kwargs["filename"]
file_content = file_content.decode("utf-8") # type: ignore
csv_file = StringIO(file_content)
df = pd.read_csv(csv_file)
# Generate a summary of the CSV data
data_summary = df.describe().to_string()
columns_info = df.dtypes.to_string()
# Create a prompt for the LLM to generate the plotting code
code_generation_prompt = f"""
{system_message}
Given the following CSV data summary and user query, create Python code to generate an appropriate graph:
Data Summary:
{data_summary}
Columns:
{columns_info}
User Query:
"Graph the data"
Generate the Python code to create the graph:
"""
code_response = llm.invoke(code_generation_prompt)
code = self.preprocess_code(cast(str, code_response.content))
# Continue with the existing code to execute and process the graph
locals_dict = {"plt": plt, "matplotlib": matplotlib, "np": np, "df": df}
file_id = None
try:
exec(code, globals(), locals_dict)
fig = locals_dict.get("fig")
if fig is None:
raise ValueError("The provided code did not create a 'fig' variable")
ax = fig.gca() # type: ignore
plot_data = None
plot_type: GraphType | None = None
if self.is_line_plot(ax):
plot_data = self.extract_line_plot_data(ax)
plot_type = GraphType.LINE_GRAPH
elif self.is_bar_plot(ax):
plot_data = self.extract_bar_plot_data(ax)
plot_type = GraphType.BAR_CHART
if plot_data:
plot_data_file = os.path.join(self.output_dir, "plot_data.json")
with open(plot_data_file, "w") as f:
json.dump(plot_data, f)
with get_session_context_manager() as db_session:
file_store = get_default_file_store(db_session)
file_id = str(uuid.uuid4())
json_content = json.dumps(plot_data)
json_bytes = json_content.encode("utf-8")
file_store.save_file(
file_name=file_id,
content=BytesIO(json_bytes),
display_name="temporary",
file_origin=FileOrigin.CHAT_UPLOAD,
file_type="json",
)
buf = BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight") # type: ignore
with open("aaa garp.png", "wb") as f:
f.write(buf.getvalue())
yield ToolResponse(
id=GRAPHING_RESPONSE_ID,
response=GraphingResponse(
file_id=str(file_id),
graph_type=plot_type.value # type: ignore
if plot_type
else None, # Use .value to get the string
plot_data=plot_data, # Pass the dictionary directly, not as a JSON string
),
)
except Exception as e:
error_msg = f"Error generating graph: {str(e)}\n{traceback.format_exc()}"
logger.error(error_msg)
yield ToolResponse(id="ERROR", response=GraphingError(error=error_msg))
def final_result(self, *args: ToolResponse) -> JSON_ro:
try:
graph_response = next(arg for arg in args if arg.id == GRAPHING_RESPONSE_ID)
return graph_response.response.dict()
except Exception as e:
return {"error": f"Unexpected error in final_result: {str(e)}"}

View File

@@ -0,0 +1,27 @@
from enum import Enum
from typing import Any

from pydantic import BaseModel


class GraphGenerationDisplay(BaseModel):
    """Frontend display payload for a generated graph."""

    # File-store id of the generated graph artifact.
    file_id: str
    # True for a line graph; False presumably means bar chart — confirm with
    # GraphType below, which only defines those two kinds.
    line_graph: bool


class GraphType(str, Enum):
    """Kinds of graph the graphing tool can produce."""

    BAR_CHART = "bar_chart"
    LINE_GRAPH = "line_graph"


class GraphingResponse(BaseModel):
    """Successful result of the graphing tool."""

    # File-store id under which the serialized plot data was saved.
    file_id: str
    # Extracted plot data (series/bars, titles, labels); None if extraction
    # produced nothing.
    plot_data: dict[str, Any] | None
    # Which kind of graph was produced.
    graph_type: GraphType


class GraphingError(BaseModel):
    """Error result of the graphing tool."""

    error: str


# ToolResponse id used to identify graphing results.
GRAPHING_RESPONSE_ID = "graphing_response"

View File

@@ -0,0 +1,8 @@
# Prompts asking the LLM to summarize graph images it just produced.
# NOTE(review): the doubled braces in "{{query}}" render literally as
# "{query}" under a single str.format pass — confirm whether the caller
# formats twice or the escaping is a typo.
NON_TOOL_CALLING_PROMPT = """
You have just created the attached images in response to the following query: "{{query}}".
Can you please summarize them in a sentence or two?
"""

TOOL_CALLING_PROMPT = """
Can you please summarize the two images you generate in a sentence or two?
"""

View File

@@ -223,7 +223,12 @@ class ImageGenerationTool(Tool):
"An error occurred during image generation. Please try again later."
)
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, llm: LLM | None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
if llm is not None:
logger.warning("LLM passed to ImageGenerationTool.run() but not used")
prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))

View File

@@ -4,7 +4,7 @@ from danswer.llm.utils import build_content_with_imgs
IMG_GENERATION_SUMMARY_PROMPT = """
You have just created the attached images in response to the following query: "{query}".
You have just created the most recent attached images in response to the following query: "{query}".
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
"""

View File

@@ -209,7 +209,12 @@ class InternetSearchTool(Tool):
],
)
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, llm: LLM | None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
if llm is not None:
logger.warning("LLM passed to InternetSearchTool.run() but not used")
query = cast(str, kwargs["internet_search_query"])
results = self._perform_search(query)

View File

@@ -262,7 +262,12 @@ class SearchTool(Tool):
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
def run(
self, llm: LLM | None = None, **kwargs: str
) -> Generator[ToolResponse, None, None]:
if llm is not None:
logger.warning("LLM passed to ImageGenerationTool.run() but not used")
query = cast(str, kwargs["query"])
if self.selected_sections:

View File

@@ -51,7 +51,9 @@ class Tool(abc.ABC):
"""Actual execution of the tool"""
@abc.abstractmethod
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
def run(
self, llm: LLM | None = None, **kwargs: Any
) -> Generator[ToolResponse, None, None]:
raise NotImplementedError
@abc.abstractmethod

View File

@@ -1,3 +1,4 @@
import base64
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
@@ -12,9 +13,10 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_paralle
class ToolRunner:
def __init__(self, tool: Tool, args: dict[str, Any]):
def __init__(self, tool: Tool, args: dict[str, Any], llm: LLM | None = None):
self.tool = tool
self.args = args
self._llm = llm
self._tool_responses: list[ToolResponse] | None = None
@@ -22,12 +24,25 @@ class ToolRunner:
return ToolCallKickoff(tool_name=self.tool.name, tool_args=self.args)
def tool_responses(self) -> Generator[ToolResponse, None, None]:
print("i am in the tool responses function")
if self._tool_responses is not None:
print("prev")
print(self._tool_responses)
yield from self._tool_responses
return
tool_responses: list[ToolResponse] = []
for tool_response in self.tool.run(**self.args):
print("runinnig the tool")
print(self.tool.name)
for tool_response in self.tool.run(llm=self._llm, **self.args):
if isinstance(tool_response.response, bytes):
tool_response.response = base64.b64encode(
tool_response.response
).decode("utf-8")
print("tool response")
yield tool_response
tool_responses.append(tool_response)
@@ -52,4 +67,5 @@ def check_which_tools_should_run_for_non_tool_calling_llm(
(tool.get_args_for_non_tool_calling_llm, (query, history, llm))
for tool in tools
]
return run_functions_tuples_in_parallel(tool_args_list)

View File

@@ -43,6 +43,7 @@ def select_single_tool_for_non_tool_calling_llm(
llm: LLM,
) -> tuple[Tool, dict[str, Any]] | None:
if len(tools_and_args) == 1:
logger.info("Only one tool available, returning it directly")
return tools_and_args[0]
tool_list_str = "\n".join(
@@ -58,21 +59,26 @@ def select_single_tool_for_non_tool_calling_llm(
tool_list=tool_list_str, chat_history=history_str, query=query
)
output = message_to_string(llm.invoke(prompt))
logger.info(f"LLM output for tool selection: {output}")
try:
# First try to match the number
number_match = re.search(r"\d+", output)
if number_match:
tool_ind = int(number_match.group())
logger.info(f"Selected tool by index: {tool_ind}")
return tools_and_args[tool_ind]
# If that fails, try to match the tool name
for tool, args in tools_and_args:
if tool.name.lower() in output.lower():
logger.info(f"Selected tool by name: {tool.name}")
return tool, args
# If that fails, return the first tool
logger.warning("Failed to match tool by index or name, returning first tool")
return tools_and_args[0]
except Exception:
except Exception as e:
logger.error(f"Failed to select single tool for non-tool-calling LLM: {output}")
logger.exception(e)
return None

View File

@@ -0,0 +1 @@
{"data": [{"x": 0.0, "y": 91.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 1.0, "y": 46.0, "width": 0.7999999999999999, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 2.0, "y": 26.41338, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 3.0, "y": 1.0, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 4.0, "y": 23.5, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 5.0, "y": 46.0, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 6.0, "y": 68.5, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 7.0, "y": 91.0, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}], "title": "Summary Statistics of Id Column", "xlabel": "Statistic", "ylabel": "Value", "xticks": [0, 1, 2, 3, 4, 5, 6, 7], "xticklabels": ["count", "mean", "std", "min", "25%", "50%", "75%", "max"]}

View File

@@ -21,6 +21,8 @@ CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
INTENT_MODEL_TAG = "v1.0.3"
# Tool call configs
MAX_TOOL_CALLS = 3
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512

BIN
backend/zample garp.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

BIN
backend/zample garph.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

BIN
backend/zzagraph.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 40 KiB

View File

@@ -211,6 +211,7 @@ services:
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
- NEXT_PUBLIC_DISABLE_CSV_DISPLAY=${NEXT_PUBLIC_DISABLE_CSV_DISPLAY:-}
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
# Enterprise Edition only
@@ -292,7 +293,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -302,7 +302,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432:5432"
- "5433:5432"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -154,7 +154,7 @@ services:
- POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
ports:
- "5432"
- "5433"
volumes:
- db_volume:/var/lib/postgresql/data

View File

@@ -62,6 +62,9 @@ ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL
ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL}
ARG NEXT_PUBLIC_DISABLE_CSV_DISPLAY
ENV NEXT_PUBLIC_DISABLE_CSV_DISPLAY=${NEXT_PUBLIC_DISABLE_CSV_DISPLAY}
RUN npx next build
# Step 2. Production image, copy all the files and run next
@@ -122,6 +125,9 @@ ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL
ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL}
ARG NEXT_PUBLIC_DISABLE_CSV_DISPLAY
ENV NEXT_PUBLIC_DISABLE_CSV_DISPLAY=${NEXT_PUBLIC_DISABLE_CSV_DISPLAY}
# Note: Don't expose ports here, Compose will handle that for us if necessary.
# If you want to run this without compose, specify the ports to
# expose via cli

17
web/components.json Normal file
View File

@@ -0,0 +1,17 @@
{
"$schema": "https://ui.shadcn.com/schema.json",
"style": "default",
"rsc": true,
"tsx": true,
"tailwind": {
"config": "tailwind.config.ts",
"css": "src/app/globals.css",
"baseColor": "neutral",
"cssVariables": true,
"prefix": ""
},
"aliases": {
"components": "@/components",
"utils": "@/lib/utils"
}
}

199
web/package-lock.json generated
View File

@@ -14,7 +14,9 @@
"@phosphor-icons/react": "^2.0.8",
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-tooltip": "^1.0.7",
"@tanstack/react-table": "^8.19.3",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",
"@types/lodash": "^4.17.0",
@@ -24,9 +26,13 @@
"@types/react-dom": "18.0.11",
"@types/uuid": "^9.0.8",
"autoprefixer": "^10.4.14",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"formik": "^2.2.9",
"fs": "^0.0.1-security",
"js-cookie": "^3.0.5",
"lodash": "^4.17.21",
"lucide-react": "^0.416.0",
"mdast-util-find-and-replace": "^3.0.1",
"next": "^14.2.3",
"npm": "^10.8.0",
@@ -39,12 +45,15 @@
"react-loader-spinner": "^5.4.5",
"react-markdown": "^9.0.1",
"react-select": "^5.8.0",
"recharts": "^2.12.7",
"rehype-prism-plus": "^2.0.0",
"remark-gfm": "^4.0.0",
"semver": "^7.5.4",
"sharp": "^0.32.6",
"swr": "^2.1.5",
"tailwind-merge": "^2.4.0",
"tailwindcss": "^3.3.1",
"tailwindcss-animate": "^1.0.7",
"typescript": "5.0.3",
"uuid": "^9.0.1",
"yup": "^1.1.1"
@@ -1249,6 +1258,25 @@
}
}
},
"node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-slot": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz",
"integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==",
"license": "MIT",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/react-compose-refs": "1.0.1"
},
"peerDependencies": {
"@types/react": "*",
"react": "^16.8 || ^17.0 || ^18.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-dismissable-layer": {
"version": "1.0.5",
"resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.0.5.tgz",
@@ -1373,6 +1401,25 @@
}
}
},
"node_modules/@radix-ui/react-popover/node_modules/@radix-ui/react-slot": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz",
"integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==",
"license": "MIT",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/react-compose-refs": "1.0.1"
},
"peerDependencies": {
"@types/react": "*",
"react": "^16.8 || ^17.0 || ^18.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-popper": {
"version": "1.1.3",
"resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.1.3.tgz",
@@ -1475,10 +1522,11 @@
}
}
},
"node_modules/@radix-ui/react-slot": {
"node_modules/@radix-ui/react-primitive/node_modules/@radix-ui/react-slot": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz",
"integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==",
"license": "MIT",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/react-compose-refs": "1.0.1"
@@ -1493,6 +1541,39 @@
}
}
},
"node_modules/@radix-ui/react-slot": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.0.tgz",
"integrity": "sha512-FUCf5XMfmW4dtYl69pdS4DbxKy8nj4M7SafBgPllysxmdachynNflAdp/gCsnYWNDnge6tI9onzMp5ARYc1KNw==",
"license": "MIT",
"dependencies": {
"@radix-ui/react-compose-refs": "1.1.0"
},
"peerDependencies": {
"@types/react": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-slot/node_modules/@radix-ui/react-compose-refs": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.0.tgz",
"integrity": "sha512-b4inOtiaOnYf9KWyO3jAeeCG6FeyfY6ldiEPanbUjWd+xIk5wZeHa8yVwmrJ2vderhu/BQvzCrJI0lHd+wIiqw==",
"license": "MIT",
"peerDependencies": {
"@types/react": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-tooltip": {
"version": "1.0.7",
"resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.0.7.tgz",
@@ -1527,6 +1608,25 @@
}
}
},
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-slot": {
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz",
"integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==",
"license": "MIT",
"dependencies": {
"@babel/runtime": "^7.13.10",
"@radix-ui/react-compose-refs": "1.0.1"
},
"peerDependencies": {
"@types/react": "*",
"react": "^16.8 || ^17.0 || ^18.0"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-use-callback-ref": {
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.1.tgz",
@@ -1699,6 +1799,26 @@
"tailwindcss": ">=3.0.0 || insiders"
}
},
"node_modules/@tanstack/react-table": {
"version": "8.19.3",
"resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.19.3.tgz",
"integrity": "sha512-MtgPZc4y+cCRtU16y1vh1myuyZ2OdkWgMEBzyjYsoMWMicKZGZvcDnub3Zwb6XF2pj9iRMvm1SO1n57lS0vXLw==",
"license": "MIT",
"dependencies": {
"@tanstack/table-core": "8.19.3"
},
"engines": {
"node": ">=12"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/tannerlinsley"
},
"peerDependencies": {
"react": ">=16.8",
"react-dom": ">=16.8"
}
},
"node_modules/@tanstack/react-virtual": {
"version": "3.5.0",
"resolved": "https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.5.0.tgz",
@@ -1715,6 +1835,19 @@
"react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0"
}
},
"node_modules/@tanstack/table-core": {
"version": "8.19.3",
"resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.19.3.tgz",
"integrity": "sha512-IqREj9ADoml9zCAouIG/5kCGoyIxPFdqdyoxis9FisXFi5vT+iYfEfLosq4xkU/iDbMcEuAj+X8dWRLvKYDNoQ==",
"license": "MIT",
"engines": {
"node": ">=12"
},
"funding": {
"type": "github",
"url": "https://github.com/sponsors/tannerlinsley"
}
},
"node_modules/@tanstack/virtual-core": {
"version": "3.5.0",
"resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.5.0.tgz",
@@ -1743,6 +1876,16 @@
"react-dom": ">=16.6.0"
}
},
"node_modules/@tremor/react/node_modules/tailwind-merge": {
"version": "1.14.0",
"resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-1.14.0.tgz",
"integrity": "sha512-3mFKyCo/MBcgyOTlrY8T7odzZFx+w+qKSMAmdFzRvqBfLlSigU6TZnlFHK0lkMwj9Bj8OYU+9yW9lmGuS0QEnQ==",
"license": "MIT",
"funding": {
"type": "github",
"url": "https://github.com/sponsors/dcastil"
}
},
"node_modules/@types/d3-array": {
"version": "3.2.1",
"resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.1.tgz",
@@ -2792,6 +2935,27 @@
"resolved": "https://registry.npmjs.org/chownr/-/chownr-1.1.4.tgz",
"integrity": "sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg=="
},
"node_modules/class-variance-authority": {
"version": "0.7.0",
"resolved": "https://registry.npmjs.org/class-variance-authority/-/class-variance-authority-0.7.0.tgz",
"integrity": "sha512-jFI8IQw4hczaL4ALINxqLEXQbWcNjoSkloa4IaufXCJr6QawJyw7tuRysRsrE8w2p/4gGaxKIt/hX3qz/IbD1A==",
"license": "Apache-2.0",
"dependencies": {
"clsx": "2.0.0"
},
"funding": {
"url": "https://joebell.co.uk"
}
},
"node_modules/class-variance-authority/node_modules/clsx": {
"version": "2.0.0",
"resolved": "https://registry.npmjs.org/clsx/-/clsx-2.0.0.tgz",
"integrity": "sha512-rQ1+kcj+ttHG0MKVGBUXwayCCF1oh39BF5COIpRzuCEv8Mwjv0XucrI2ExNTOn9IlLifGClWQcU9BrZORvtw6Q==",
"license": "MIT",
"engines": {
"node": ">=6"
}
},
"node_modules/client-only": {
"version": "0.0.1",
"resolved": "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz",
@@ -2801,6 +2965,7 @@
"version": "2.1.1",
"resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz",
"integrity": "sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==",
"license": "MIT",
"engines": {
"node": ">=6"
}
@@ -4172,6 +4337,12 @@
"url": "https://github.com/sponsors/rawify"
}
},
"node_modules/fs": {
"version": "0.0.1-security",
"resolved": "https://registry.npmjs.org/fs/-/fs-0.0.1-security.tgz",
"integrity": "sha512-3XY9e1pP0CVEUCdj5BmfIZxRBTSDycnbqhIOGec9QYtmVH2fbLpj86CFWkrNOkt/Fvty4KZG5lTglL9j/gJ87w==",
"license": "ISC"
},
"node_modules/fs-constants": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/fs-constants/-/fs-constants-1.0.0.tgz",
@@ -5477,6 +5648,15 @@
"node": "14 || >=16.14"
}
},
"node_modules/lucide-react": {
"version": "0.416.0",
"resolved": "https://registry.npmjs.org/lucide-react/-/lucide-react-0.416.0.tgz",
"integrity": "sha512-wPWxTzdss1CTz2aqcNWNlbh4YSnH9neJWP3RaeXepxpLCTW+pmu7WcT/wxJe+Q7Y7DqGOxAqakJv0pIK3431Ag==",
"license": "ISC",
"peerDependencies": {
"react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0"
}
},
"node_modules/markdown-table": {
"version": "3.0.3",
"resolved": "https://registry.npmjs.org/markdown-table/-/markdown-table-3.0.3.tgz",
@@ -9856,6 +10036,7 @@
"version": "2.12.7",
"resolved": "https://registry.npmjs.org/recharts/-/recharts-2.12.7.tgz",
"integrity": "sha512-hlLJMhPQfv4/3NBSAyq3gzGg4h2v69RJh6KU7b3pXYNNAELs9kEoXOjbkxdXpALqKBoVmVptGfLpxdaVYqjmXQ==",
"license": "MIT",
"dependencies": {
"clsx": "^2.0.0",
"eventemitter3": "^4.0.1",
@@ -10777,9 +10958,10 @@
"integrity": "sha512-Cat63mxsVJlzYvN51JmVXIgNoUokrIaT2zLclCXjRd8boZ0004U4KCs/sToJ75C6sdlByWxpYnb5Boif1VSFew=="
},
"node_modules/tailwind-merge": {
"version": "1.14.0",
"resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-1.14.0.tgz",
"integrity": "sha512-3mFKyCo/MBcgyOTlrY8T7odzZFx+w+qKSMAmdFzRvqBfLlSigU6TZnlFHK0lkMwj9Bj8OYU+9yW9lmGuS0QEnQ==",
"version": "2.4.0",
"resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-2.4.0.tgz",
"integrity": "sha512-49AwoOQNKdqKPd9CViyH5wJoSKsCDjUlzL8DxuGp3P1FsGY36NJDAa18jLZcaHAUUuTj+JB8IAo8zWgBNvBF7A==",
"license": "MIT",
"funding": {
"type": "github",
"url": "https://github.com/sponsors/dcastil"
@@ -10821,6 +11003,15 @@
"node": ">=14.0.0"
}
},
"node_modules/tailwindcss-animate": {
"version": "1.0.7",
"resolved": "https://registry.npmjs.org/tailwindcss-animate/-/tailwindcss-animate-1.0.7.tgz",
"integrity": "sha512-bl6mpH3T7I3UFxuvDEXLxy/VuFxBk5bbzplh7tXI68mwMokNYd1t9qPBHlnyTwfa4JGC4zP516I1hYYtQ/vspA==",
"license": "MIT",
"peerDependencies": {
"tailwindcss": ">=3.0.0 || insiders"
}
},
"node_modules/tailwindcss/node_modules/postcss-selector-parser": {
"version": "6.0.16",
"resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.0.16.tgz",

View File

@@ -15,7 +15,9 @@
"@phosphor-icons/react": "^2.0.8",
"@radix-ui/react-dialog": "^1.0.5",
"@radix-ui/react-popover": "^1.0.7",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-tooltip": "^1.0.7",
"@tanstack/react-table": "^8.19.3",
"@tremor/react": "^3.9.2",
"@types/js-cookie": "^3.0.3",
"@types/lodash": "^4.17.0",
@@ -25,9 +27,13 @@
"@types/react-dom": "18.0.11",
"@types/uuid": "^9.0.8",
"autoprefixer": "^10.4.14",
"class-variance-authority": "^0.7.0",
"clsx": "^2.1.1",
"formik": "^2.2.9",
"fs": "^0.0.1-security",
"js-cookie": "^3.0.5",
"lodash": "^4.17.21",
"lucide-react": "^0.416.0",
"mdast-util-find-and-replace": "^3.0.1",
"next": "^14.2.3",
"npm": "^10.8.0",
@@ -40,12 +46,15 @@
"react-loader-spinner": "^5.4.5",
"react-markdown": "^9.0.1",
"react-select": "^5.8.0",
"recharts": "^2.12.7",
"rehype-prism-plus": "^2.0.0",
"remark-gfm": "^4.0.0",
"semver": "^7.5.4",
"sharp": "^0.32.6",
"swr": "^2.1.5",
"tailwind-merge": "^2.4.0",
"tailwindcss": "^3.3.1",
"tailwindcss-animate": "^1.0.7",
"typescript": "5.0.3",
"uuid": "^9.0.1",
"yup": "^1.1.1"

View File

@@ -154,7 +154,6 @@ export default function AddConnector({
initialValues={createConnectorInitialValues(connector)}
validationSchema={createConnectorValidationSchema(connector)}
onSubmit={async (values) => {
console.log(" Iam submiing the connector");
const {
name,
groups,

View File

@@ -0,0 +1,44 @@
// This module handles AI message sequences - consecutive AI messages that are streamed
// separately but represent a single logical message. These utilities are used for
// processing and displaying such sequences in the chat interface.
import { Message } from "@/app/chat/interfaces";
import { DanswerDocument } from "@/lib/search/interfaces";
// Retrieves the consecutive AI messages at the end of the message history.
// This is useful for combining or processing the latest AI response sequence.
// Retrieves the consecutive AI ("assistant") messages at the end of the
// message history. This is useful for combining or processing the latest AI
// response sequence.
//
// Returns the trailing run of assistant messages in their original order, or
// an empty array when the history is empty or ends with a non-assistant
// message.
export function getConsecutiveAIMessagesAtEnd(
  messageHistory: Message[]
): Message[] {
  // Walk backwards to find where the trailing assistant run begins, then
  // slice once. This avoids the original's O(n^2) repeated unshift and the
  // untyped evolving array it pushed possibly-undefined elements into.
  let start = messageHistory.length;
  while (start > 0 && messageHistory[start - 1]?.type === "assistant") {
    start--;
  }
  return messageHistory.slice(start);
}
// Extracts unique documents from a sequence of AI messages.
// This is used to compile a comprehensive list of referenced documents
// across multiple parts of an AI response.
// Extracts unique documents from a sequence of AI messages.
// Used to compile a comprehensive list of referenced documents across
// multiple parts of an AI response. The first occurrence of each
// document_id/chunk_ind pair wins; insertion order is preserved.
export function getUniqueDocumentsFromAIMessages(
  messages: Message[]
): DanswerDocument[] {
  const seenDocuments = new Map<string, DanswerDocument>();
  for (const message of messages) {
    for (const doc of message.documents ?? []) {
      const dedupeKey = `${doc.document_id}-${doc.chunk_ind}`;
      if (!seenDocuments.has(dedupeKey)) {
        seenDocuments.set(dedupeKey, doc);
      }
    }
  }
  return [...seenDocuments.values()];
}

View File

@@ -101,6 +101,8 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import { SEARCH_TOOL_NAME } from "./tools/constants";
import { useUser } from "@/components/user/UserProvider";
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
import { Button } from "@tremor/react";
import dynamic from "next/dynamic";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -133,7 +135,6 @@ export function ChatPage({
} = useChatContext();
const [showApiKeyModal, setShowApiKeyModal] = useState(true);
const { user, refreshUser, isLoadingUser } = useUser();
// chat session
@@ -248,13 +249,13 @@ export function ChatPage({
if (
lastMessage &&
lastMessage.type === "assistant" &&
lastMessage.toolCalls[0] &&
lastMessage.toolCalls[0].tool_result === undefined
lastMessage.toolCall &&
lastMessage.toolCall.tool_result === undefined
) {
const newCompleteMessageMap = new Map(
currentMessageMap(completeMessageDetail)
);
const updatedMessage = { ...lastMessage, toolCalls: [] };
const updatedMessage = { ...lastMessage, toolCall: null };
newCompleteMessageMap.set(lastMessage.messageId, updatedMessage);
updateCompleteMessageDetail(currentSession, newCompleteMessageMap);
}
@@ -483,7 +484,7 @@ export function ChatPage({
message: "",
type: "system",
files: [],
toolCalls: [],
toolCall: null,
parentMessageId: null,
childrenMessageIds: [firstMessageId],
latestChildMessageId: firstMessageId,
@@ -510,6 +511,7 @@ export function ChatPage({
}
newCompleteMessageMap.set(message.messageId, message);
});
// if specified, make these new message the latest of the current message chain
if (makeLatestChildMessage) {
const currentMessageChain = buildLatestMessageChain(
@@ -1044,8 +1046,6 @@ export function ChatPage({
resetInputBar();
let messageUpdates: Message[] | null = null;
let answer = "";
let stopReason: StreamStopReason | null = null;
let query: string | null = null;
let retrievalType: RetrievalType =
@@ -1058,12 +1058,14 @@ export function ChatPage({
let stackTrace: string | null = null;
let finalMessage: BackendMessage | null = null;
let toolCalls: ToolCallMetadata[] = [];
let toolCall: ToolCallMetadata | null = null;
let initialFetchDetails: null | {
user_message_id: number;
assistant_message_id: number;
frozenMessageMap: Map<number, Message>;
initialDynamicParentMessage: Message;
initialDynamicAssistantMessage: Message;
} = null;
try {
@@ -1122,7 +1124,16 @@ export function ChatPage({
return new Promise((resolve) => setTimeout(resolve, ms));
};
let updateFn = (messages: Message[]) => {
return upsertToCompleteMessageMap({
messages: messages,
chatSessionId: currChatSessionId,
});
};
await delay(50);
let dynamicParentMessage: Message | null = null;
let dynamicAssistantMessage: Message | null = null;
while (!stack.isComplete || !stack.isEmpty()) {
await delay(0.5);
@@ -1132,6 +1143,7 @@ export function ChatPage({
continue;
}
console.log(packet);
if (!initialFetchDetails) {
if (!Object.hasOwn(packet, "user_message_id")) {
console.error(
@@ -1156,12 +1168,12 @@ export function ChatPage({
messageUpdates = [
{
messageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
? regenerationRequest?.messageId
: user_message_id,
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
toolCall: null,
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
];
@@ -1176,22 +1188,109 @@ export function ChatPage({
});
}
const { messageMap: currentFrozenMessageMap } =
let { messageMap: currentFrozenMessageMap } =
upsertToCompleteMessageMap({
messages: messageUpdates,
chatSessionId: currChatSessionId,
});
const frozenMessageMap = currentFrozenMessageMap;
let frozenMessageMap = currentFrozenMessageMap;
regenerationRequest?.parentMessage;
let initialDynamicParentMessage: Message = regenerationRequest
? regenerationRequest?.parentMessage
: {
messageId: user_message_id!,
message: "",
type: "user",
files: currentMessageFiles,
toolCall: null,
parentMessageId: error ? null : lastSuccessfulMessageId,
childrenMessageIds: [assistant_message_id!],
latestChildMessageId: -100,
};
let initialDynamicAssistantMessage: Message = {
messageId: assistant_message_id!,
message: "",
type: "assistant",
retrievalType,
query: finalMessage?.rephrased_query || query,
documents: finalMessage?.context_docs?.top_documents || documents,
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCall: finalMessage?.tool_call || toolCall,
parentMessageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: user_message_id,
alternateAssistantID: alternativeAssistant?.id,
stackTrace: stackTrace,
overridden_model: finalMessage?.overridden_model,
stopReason: stopReason,
};
initialFetchDetails = {
frozenMessageMap,
assistant_message_id,
user_message_id,
initialDynamicParentMessage,
initialDynamicAssistantMessage,
};
resetRegenerationState();
} else {
const { user_message_id, frozenMessageMap } = initialFetchDetails;
let {
initialDynamicParentMessage,
initialDynamicAssistantMessage,
user_message_id,
frozenMessageMap,
} = initialFetchDetails;
if (
dynamicParentMessage === null &&
dynamicAssistantMessage === null
) {
dynamicParentMessage = initialDynamicParentMessage;
dynamicAssistantMessage = initialDynamicAssistantMessage;
dynamicParentMessage.message = currMessage;
}
if (!dynamicAssistantMessage || !dynamicParentMessage) {
return;
}
if (Object.hasOwn(packet, "user_message_id")) {
let newParentMessageId = dynamicParentMessage.messageId;
const messageResponseIDInfo = packet as MessageResponseIDInfo;
for (const key in dynamicAssistantMessage) {
(dynamicParentMessage as Record<string, any>)[key] = (
dynamicAssistantMessage as Record<string, any>
)[key];
}
dynamicParentMessage.parentMessageId = newParentMessageId;
dynamicParentMessage.latestChildMessageId =
messageResponseIDInfo.reserved_assistant_message_id;
dynamicParentMessage.childrenMessageIds = [
messageResponseIDInfo.reserved_assistant_message_id,
];
dynamicParentMessage.messageId =
messageResponseIDInfo.user_message_id!;
dynamicAssistantMessage = {
messageId: messageResponseIDInfo.reserved_assistant_message_id,
type: "assistant",
message: "",
documents: [],
retrievalType: undefined,
toolCall: null,
files: [],
parentMessageId: dynamicParentMessage.messageId,
childrenMessageIds: [],
latestChildMessageId: null,
};
}
setChatState((prevState) => {
if (prevState.get(chatSessionIdRef.current!) === "loading") {
@@ -1204,37 +1303,37 @@ export function ChatPage({
});
if (Object.hasOwn(packet, "answer_piece")) {
answer += (packet as AnswerPiecePacket).answer_piece;
dynamicAssistantMessage.message += (
packet as AnswerPiecePacket
).answer_piece;
} else if (Object.hasOwn(packet, "top_documents")) {
documents = (packet as DocumentsResponse).top_documents;
dynamicAssistantMessage.documents = (
packet as DocumentsResponse
).top_documents;
dynamicAssistantMessage.retrievalType = RetrievalType.Search;
retrievalType = RetrievalType.Search;
if (documents && documents.length > 0) {
// point to the latest message (we don't know the messageId yet, which is why
// we have to use -1)
setSelectedMessageForDocDisplay(user_message_id);
}
} else if (Object.hasOwn(packet, "tool_name")) {
toolCalls = [
{
tool_name: (packet as ToolCallMetadata).tool_name,
tool_args: (packet as ToolCallMetadata).tool_args,
tool_result: (packet as ToolCallMetadata).tool_result,
},
];
dynamicAssistantMessage.toolCall = {
tool_name: (packet as ToolCallMetadata).tool_name,
tool_args: (packet as ToolCallMetadata).tool_args,
tool_result: (packet as ToolCallMetadata).tool_result,
};
if (
!toolCalls[0].tool_result ||
toolCalls[0].tool_result == undefined
dynamicAssistantMessage.toolCall.tool_name === SEARCH_TOOL_NAME
) {
dynamicAssistantMessage.query =
dynamicAssistantMessage.toolCall.tool_args.query;
}
if (
!dynamicAssistantMessage.toolCall ||
!dynamicAssistantMessage.toolCall.tool_result ||
dynamicAssistantMessage.toolCall.tool_result == undefined
) {
updateChatState("toolBuilding", frozenSessionId);
} else {
updateChatState("streaming", frozenSessionId);
}
// This will be consolidated in upcoming tool calls udpate,
// but for now, we need to set query as early as possible
if (toolCalls[0].tool_name == SEARCH_TOOL_NAME) {
query = toolCalls[0].tool_args["query"];
}
} else if (Object.hasOwn(packet, "file_ids")) {
aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
(fileId) => {
@@ -1244,82 +1343,54 @@ export function ChatPage({
};
}
);
dynamicAssistantMessage.files = aiMessageImages;
} else if (Object.hasOwn(packet, "error")) {
error = (packet as StreamingError).error;
stackTrace = (packet as StreamingError).stack_trace;
dynamicAssistantMessage.stackTrace = (
packet as StreamingError
).stack_trace;
} else if (Object.hasOwn(packet, "message_id")) {
finalMessage = packet as BackendMessage;
dynamicAssistantMessage = {
...dynamicAssistantMessage,
...finalMessage,
};
} else if (Object.hasOwn(packet, "stop_reason")) {
const stop_reason = (packet as StreamStopInfo).stop_reason;
if (stop_reason === StreamStopReason.CONTEXT_LENGTH) {
updateCanContinue(true, frozenSessionId);
}
}
if (!Object.hasOwn(packet, "stop_reason")) {
updateFn = (messages: Message[]) => {
const replacementsMap = regenerationRequest
? new Map([
[
regenerationRequest?.parentMessage?.messageId,
regenerationRequest?.parentMessage?.messageId,
],
[
dynamicParentMessage?.messageId,
dynamicAssistantMessage?.messageId,
],
] as [number, number][])
: null;
// on initial message send, we insert a dummy system message
// set this as the parent here if no parent is set
parentMessage =
parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
return upsertToCompleteMessageMap({
messages: messages,
replacementsMap: replacementsMap,
completeMessageMapOverride: frozenMessageMap,
chatSessionId: frozenSessionId!,
});
};
const updateFn = (messages: Message[]) => {
const replacementsMap = regenerationRequest
? new Map([
[
regenerationRequest?.parentMessage?.messageId,
regenerationRequest?.parentMessage?.messageId,
],
[
regenerationRequest?.messageId,
initialFetchDetails?.assistant_message_id,
],
] as [number, number][])
: null;
return upsertToCompleteMessageMap({
messages: messages,
replacementsMap: replacementsMap,
completeMessageMapOverride: frozenMessageMap,
chatSessionId: frozenSessionId!,
});
};
updateFn([
{
messageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: initialFetchDetails.user_message_id!,
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
parentMessageId: error ? null : lastSuccessfulMessageId,
childrenMessageIds: [
...(regenerationRequest?.parentMessage?.childrenMessageIds ||
[]),
initialFetchDetails.assistant_message_id!,
],
latestChildMessageId: initialFetchDetails.assistant_message_id,
},
{
messageId: initialFetchDetails.assistant_message_id!,
message: error || answer,
type: error ? "error" : "assistant",
retrievalType,
query: finalMessage?.rephrased_query || query,
documents:
finalMessage?.context_docs?.top_documents || documents,
citations: finalMessage?.citations || {},
files: finalMessage?.files || aiMessageImages || [],
toolCalls: finalMessage?.tool_calls || toolCalls,
parentMessageId: regenerationRequest
? regenerationRequest?.parentMessage?.messageId!
: initialFetchDetails.user_message_id,
alternateAssistantID: alternativeAssistant?.id,
stackTrace: stackTrace,
overridden_model: finalMessage?.overridden_model,
stopReason: stopReason,
},
]);
let { messageMap } = updateFn([
dynamicParentMessage,
dynamicAssistantMessage,
]);
frozenMessageMap = messageMap;
}
}
}
}
@@ -1333,7 +1404,7 @@ export function ChatPage({
message: currMessage,
type: "user",
files: currentMessageFiles,
toolCalls: [],
toolCall: null,
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
},
{
@@ -1343,7 +1414,7 @@ export function ChatPage({
message: errorMsg,
type: "error",
files: aiMessageImages || [],
toolCalls: [],
toolCall: null,
parentMessageId:
initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
},
@@ -1962,9 +2033,8 @@ export function ChatPage({
completeMessageDetail
);
const messageReactComponentKey = `${i}-${currentSessionId()}`;
const parentMessage = message.parentMessageId
? messageMap.get(message.parentMessageId)
: null;
const parentMessage =
i > 1 ? messageHistory[i - 1] : null;
if (message.type === "user") {
if (
(currentSessionChatState == "loading" &&
@@ -2055,6 +2125,25 @@ export function ChatPage({
) {
return <></>;
}
const mostRecentNonAIParent = messageHistory
.slice(0, i)
.reverse()
.find((msg) => msg.type !== "assistant");
const hasChildMessage =
message.latestChildMessageId !== null &&
message.latestChildMessageId !== undefined;
const childMessage = hasChildMessage
? messageMap.get(
message.latestChildMessageId!
)
: null;
const hasParentAI =
parentMessage?.type == "assistant";
const hasChildAI =
childMessage?.type == "assistant";
return (
<div
id={`message-${message.messageId}`}
@@ -2066,6 +2155,9 @@ export function ChatPage({
}
>
<AIMessage
setPopup={setPopup}
hasChildAI={hasChildAI}
hasParentAI={hasParentAI}
continueGenerating={
i == messageHistory.length - 1 &&
currentCanContinue()
@@ -2075,7 +2167,7 @@ export function ChatPage({
overriddenModel={message.overridden_model}
regenerate={createRegenerator({
messageId: message.messageId,
parentMessage: parentMessage!,
parentMessage: mostRecentNonAIParent!,
})}
otherMessagesCanSwitchTo={
parentMessage?.childrenMessageIds || []
@@ -2112,18 +2204,15 @@ export function ChatPage({
}
messageId={message.messageId}
content={message.message}
// content={message.message}
files={message.files}
query={
messageHistory[i]?.query || undefined
}
personaName={liveAssistant.name}
citedDocuments={getCitedDocumentsFromMessage(
message
)}
toolCall={
message.toolCalls &&
message.toolCalls[0]
message.toolCall && message.toolCall
}
isComplete={
i !== messageHistory.length - 1 ||
@@ -2147,7 +2236,6 @@ export function ChatPage({
])
}
handleSearchQueryEdit={
i === messageHistory.length - 1 &&
currentSessionChatState == "input"
? (newQuery) => {
if (!previousMessage) {
@@ -2231,7 +2319,6 @@ export function ChatPage({
<AIMessage
currentPersona={liveAssistant}
messageId={message.messageId}
personaName={liveAssistant.name}
content={
<p className="text-red-700 text-sm my-auto">
{message.message}
@@ -2279,7 +2366,6 @@ export function ChatPage({
alternativeAssistant
}
messageId={null}
personaName={liveAssistant.name}
content={
<div
key={"Generating"}
@@ -2299,7 +2385,6 @@ export function ChatPage({
<AIMessage
currentPersona={liveAssistant}
messageId={-1}
personaName={liveAssistant.name}
content={
<p className="text-red-700 text-sm my-auto">
{loadingError}

View File

@@ -1,70 +1,78 @@
import { FiFileText } from "react-icons/fi";
import { useState, useRef, useEffect } from "react";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { ExpandTwoIcon } from "@/components/icons/icons";
export function DocumentPreview({
fileName,
maxWidth,
alignBubble,
open,
}: {
fileName: string;
open?: () => void;
maxWidth?: string;
alignBubble?: boolean;
}) {
const [isOverflowing, setIsOverflowing] = useState(false);
const fileNameRef = useRef<HTMLDivElement>(null);
useEffect(() => {
if (fileNameRef.current) {
setIsOverflowing(
fileNameRef.current.scrollWidth > fileNameRef.current.clientWidth
);
}
}, [fileName]);
return (
<div
className={`
${alignBubble && "w-64"}
flex
items-center
p-2
p-3
bg-hover
border
border-border
rounded-md
rounded-lg
box-border
h-16
h-20
hover:shadow-sm
transition-all
`}
>
<div className="flex-shrink-0">
<div
className="
w-12
h-12
w-14
h-14
bg-document
flex
items-center
justify-center
rounded-md
rounded-lg
transition-all
duration-200
hover:bg-document-dark
"
>
<FiFileText className="w-6 h-6 text-white" />
<FiFileText className="w-7 h-7 text-white" />
</div>
</div>
<div className="ml-4 relative">
<div className="ml-4 flex-grow">
<Tooltip content={fileName} side="top" align="start">
<div
ref={fileNameRef}
className={`font-medium text-sm line-clamp-1 break-all ellipses ${
className={`font-medium text-sm line-clamp-1 break-all ellipsis ${
maxWidth ? maxWidth : "max-w-48"
}`}
>
{fileName}
</div>
</Tooltip>
<div className="text-subtle text-sm">Document</div>
<div className="text-subtle text-xs mt-1">Document</div>
</div>
{open && (
<button
onClick={() => open()}
className="ml-2 p-2 rounded-full hover:bg-gray-200 transition-colors duration-200"
aria-label="Expand document"
>
<ExpandTwoIcon className="w-5 h-5 text-gray-600" />
</button>
)}
</div>
);
}

View File

@@ -4,6 +4,7 @@ import {
SearchDanswerDocument,
StreamStopReason,
} from "@/lib/search/interfaces";
import { GraphChunk } from "./message/Messages";
export enum RetrievalType {
None = "none",
@@ -32,6 +33,7 @@ export enum ChatFileType {
IMAGE = "image",
DOCUMENT = "document",
PLAIN_TEXT = "plain_text",
CSV = "csv",
}
export interface FileDescriptor {
@@ -85,7 +87,7 @@ export interface Message {
documents?: DanswerDocument[] | null;
citations?: CitationMap;
files: FileDescriptor[];
toolCalls: ToolCallMetadata[];
toolCall: ToolCallMetadata | null;
// for rebuilding the message tree
parentMessageId: number | null;
childrenMessageIds?: number[];
@@ -94,6 +96,7 @@ export interface Message {
stackTrace?: string | null;
overridden_model?: string;
stopReason?: StreamStopReason | null;
graphs?: GraphChunk[];
}
export interface BackendChatSession {
@@ -120,7 +123,7 @@ export interface BackendMessage {
time_sent: string;
citations: CitationMap;
files: FileDescriptor[];
tool_calls: ToolCallFinalResult[];
tool_call: ToolCallFinalResult | null;
alternate_assistant_id?: number | null;
overridden_model?: string;
}
@@ -143,3 +146,10 @@ export interface StreamingError {
error: string;
stack_trace: string;
}
export interface ImageGenerationResult {
revised_prompt: string;
url: string;
}
export type ImageGenerationResults = ImageGenerationResult[];

View File

@@ -60,6 +60,7 @@ export function getChatRetentionInfo(
showRetentionWarning,
};
}
import { GraphChunk } from "./message/Messages";
export async function updateModelOverrideForChatSession(
chatSessionId: number,
@@ -113,7 +114,8 @@ export type PacketType =
| ImageGenerationDisplay
| StreamingError
| MessageResponseIDInfo
| StreamStopInfo;
| StreamStopInfo
| GraphChunk;
export async function* sendMessage({
regenerate,
@@ -435,7 +437,7 @@ export function processRawChatHistory(
citations: messageInfo?.citations || {},
}
: {}),
toolCalls: messageInfo.tool_calls,
toolCall: messageInfo.tool_call,
parentMessageId: messageInfo.parent_message,
childrenMessageIds: [],
latestChildMessageId: messageInfo.latest_child_message,
@@ -479,6 +481,7 @@ export function buildLatestMessageChain(
let currMessage: Message | null = rootMessage;
while (currMessage) {
finalMessageList.push(currMessage);
const childMessageNumber = currMessage.latestChildMessageId;
if (childMessageNumber && messageMap.has(childMessageNumber)) {
currMessage = messageMap.get(childMessageNumber) as Message;

View File

@@ -0,0 +1,81 @@
import React, { useState } from "react";
import { usePopup } from "@/components/admin/connectors/Popup";
import { Button } from "@/components/Button";
import { uploadFile } from "@/app/admin/assistants/lib";
// Props for the JSONUpload component.
interface JSONUploadProps {
  // Invoked after a successful upload with the file contents:
  // a parsed object for "application/json" files, the raw string otherwise.
  onUploadSuccess: (jsonData: any) => void;
}
/**
 * File picker that reads the selected file as text, uploads it via
 * `uploadFile`, and hands the (parsed, for JSON files) contents to the
 * caller. Success and failure are surfaced through the popup.
 */
export const JSONUpload: React.FC<JSONUploadProps> = ({ onUploadSuccess }) => {
  const { popup, setPopup } = usePopup();
  // Raw text contents of the selected file; populated asynchronously by
  // the FileReader once the user picks a file.
  const [credentialJsonStr, setCredentialJsonStr] = useState<
    string | undefined
  >();
  const [file, setFile] = useState<File | null>(null);

  const uploadJSON = async () => {
    if (!file) {
      setPopup({
        type: "error",
        message: "Please select a file to upload.",
      });
      return;
    }
    // The upload button is disabled until the FileReader finishes, but
    // guard explicitly rather than relying on a non-null assertion.
    if (credentialJsonStr === undefined) {
      setPopup({
        type: "error",
        message: "Please select a file to upload.",
      });
      return;
    }

    try {
      // JSON files are parsed so the caller receives an object; any other
      // file type is passed through as the raw string contents.
      const parsedData =
        file.type === "application/json"
          ? JSON.parse(credentialJsonStr)
          : credentialJsonStr;

      await uploadFile(file);
      onUploadSuccess(parsedData);
      setPopup({
        type: "success",
        message: "File uploaded successfully!",
      });
    } catch (error) {
      // Covers both JSON.parse failures and upload failures.
      console.error("Error uploading file:", error);
      setPopup({
        type: "error",
        message: `Failed to upload file - ${error}`,
      });
    }
  };

  return (
    <div>
      {popup}
      <input
        className="mr-3 text-sm text-gray-900 border border-gray-300 rounded-lg cursor-pointer bg-background dark:text-gray-400 focus:outline-none dark:bg-gray-700 dark:border-gray-600 dark:placeholder-gray-400"
        type="file"
        onChange={(event) => {
          if (!event.target.files) {
            return;
          }
          const file = event.target.files[0];
          setFile(file);

          // Read the file's text so it is available (and the upload
          // button enabled) by the time the user clicks "Upload JSON".
          const reader = new FileReader();
          reader.onload = function (loadEvent) {
            if (!loadEvent?.target?.result) {
              return;
            }
            const fileContents = loadEvent.target.result;
            setCredentialJsonStr(fileContents as string);
          };
          reader.readAsText(file);
        }}
      />

      <Button disabled={!credentialJsonStr} onClick={uploadJSON}>
        Upload JSON
      </Button>
    </div>
  );
};

View File

@@ -1,6 +1,7 @@
"use client";
import {
FiImage,
FiEdit2,
FiChevronRight,
FiChevronLeft,
@@ -8,25 +9,22 @@ import {
FiGlobe,
} from "react-icons/fi";
import { FeedbackType } from "../types";
import {
Dispatch,
SetStateAction,
useContext,
useEffect,
useRef,
useState,
} from "react";
import { useContext, useEffect, useRef, useState } from "react";
import ReactMarkdown from "react-markdown";
import {
DanswerDocument,
FilteredDanswerDocument,
} from "@/lib/search/interfaces";
import { SearchSummary } from "./SearchSummary";
import { SourceIcon } from "@/components/SourceIcon";
import { SkippedSearch } from "./SkippedSearch";
import remarkGfm from "remark-gfm";
import { CopyButton } from "@/components/CopyButton";
import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces";
import {
ChatFileType,
FileDescriptor,
ImageGenerationResults,
ToolCallMetadata,
} from "../interfaces";
import {
IMAGE_GENERATION_TOOL_NAME,
SEARCH_TOOL_NAME,
@@ -44,13 +42,11 @@ import "./custom-code-styles.css";
import { Persona } from "@/app/admin/assistants/interfaces";
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
import { Citation } from "@/components/search/results/Citation";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
import {
ThumbsUpIcon,
ThumbsDownIcon,
LikeFeedback,
DislikeFeedback,
ToolCallIcon,
} from "@/components/icons/icons";
import {
CustomTooltip,
@@ -59,12 +55,23 @@ import {
import { ValidSources } from "@/lib/types";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useMouseTracking } from "./hooks";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
import RegenerateOption from "../RegenerateOption";
import { LlmOverride } from "@/lib/hooks";
import { ContinueGenerating } from "./ContinueMessage";
import DualPromptDisplay from "../tools/ImagePromptCitaiton";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { Popover } from "@/components/popover/Popover";
import { DISABLED_CSV_DISPLAY } from "@/lib/constants";
import {
LineChartDisplay,
ModalChartWrapper,
} from "../../../components/chat_display/graphs/LineChartDisplay";
import BarChartDisplay from "@/components/chat_display/graphs/BarChart";
import ToolResult, {
FileWrapper,
} from "@/components/chat_display/InteractiveToolResult";
const TOOLS_WITH_CUSTOM_HANDLING = [
SEARCH_TOOL_NAME,
@@ -80,8 +87,12 @@ function FileDisplay({
alignBubble?: boolean;
}) {
const imageFiles = files.filter((file) => file.type === ChatFileType.IMAGE);
const nonImgFiles = files.filter((file) => file.type !== ChatFileType.IMAGE);
const nonImgFiles = files.filter(
(file) => file.type !== ChatFileType.IMAGE && file.type !== ChatFileType.CSV
);
const csvImgFiles = files.filter((file) => file.type == ChatFileType.CSV);
const [close, setClose] = useState(true);
return (
<>
{nonImgFiles && nonImgFiles.length > 0 && (
@@ -104,6 +115,36 @@ function FileDisplay({
</div>
</div>
)}
{csvImgFiles && csvImgFiles.length > 0 && (
<div className={` ${alignBubble && "ml-auto"} mt-2 auto mb-4`}>
<div className="flex flex-col gap-2">
{csvImgFiles.map((file) => {
return (
<div key={file.id} className="w-fit">
{close && !DISABLED_CSV_DISPLAY ? (
<>
<ToolResult
csvFileDescriptor={file}
close={() => setClose(false)}
/>
</>
) : (
<DocumentPreview
open={
DISABLED_CSV_DISPLAY ? undefined : () => setClose(true)
}
fileName={file.name || file.id}
maxWidth="max-w-64"
alignBubble={alignBubble}
/>
)}
</div>
);
})}
</div>
</div>
)}
{/* <LineChartDisplay /> */}
{imageFiles && imageFiles.length > 0 && (
<div
id="danswer-image"
@@ -120,7 +161,20 @@ function FileDisplay({
);
}
enum GraphType {
BAR_CHART = "bar_chart",
LINE_GRAPH = "line_graph",
}
export interface GraphChunk {
file_id: string;
plot_data: Record<string, any> | null;
graph_type: GraphType | null;
}
export const AIMessage = ({
hasChildAI,
hasParentAI,
regenerate,
overriddenModel,
continueGenerating,
@@ -129,12 +183,12 @@ export const AIMessage = ({
toggleDocumentSelection,
alternativeAssistant,
docs,
graphs = [],
messageId,
content,
files,
selectedDocuments,
query,
personaName,
citedDocuments,
toolCall,
isComplete,
@@ -148,8 +202,12 @@ export const AIMessage = ({
currentPersona,
otherMessagesCanSwitchTo,
onMessageSelection,
setPopup,
}: {
shared?: boolean;
hasChildAI?: boolean;
hasParentAI?: boolean;
graphs?: GraphChunk[];
isActive?: boolean;
continueGenerating?: () => void;
otherMessagesCanSwitchTo?: number[];
@@ -163,9 +221,8 @@ export const AIMessage = ({
content: string | JSX.Element;
files?: FileDescriptor[];
query?: string;
personaName?: string;
citedDocuments?: [string, DanswerDocument][] | null;
toolCall?: ToolCallMetadata;
toolCall?: ToolCallMetadata | null;
isComplete?: boolean;
hasDocs?: boolean;
handleFeedback?: (feedbackType: FeedbackType) => void;
@@ -176,7 +233,11 @@ export const AIMessage = ({
retrievalDisabled?: boolean;
overriddenModel?: string;
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
setPopup?: (popupSpec: PopupSpec | null) => void;
}) => {
console.log(toolCall);
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
const toolCallGenerating = toolCall && !toolCall.tool_result;
const processContent = (content: string | JSX.Element) => {
if (typeof content !== "string") {
@@ -199,12 +260,20 @@ export const AIMessage = ({
return content;
}
}
if (
isComplete &&
toolCall?.tool_result &&
toolCall.tool_name == IMAGE_GENERATION_TOOL_NAME
) {
return content + ` [${toolCall.tool_name}]()`;
}
return content + (!isComplete && !toolCallGenerating ? " [*]() " : "");
};
const finalContent = processContent(content as string);
const [isRegenerateHovered, setIsRegenerateHovered] = useState(false);
const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking();
const settings = useContext(SettingsContext);
@@ -274,39 +343,50 @@ export const AIMessage = ({
<div
id="danswer-ai-message"
ref={trackedElementRef}
className={"py-5 ml-4 px-5 relative flex "}
className={`${hasParentAI ? "pb-5" : "py-5"} px-2 lg:px-5 relative flex `}
>
<div
className={`mx-auto ${shared ? "w-full" : "w-[90%]"} max-w-message-max`}
>
<div className={`desktop:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
<div className="flex">
<AssistantIcon
size="small"
assistant={alternativeAssistant || currentPersona}
/>
{!hasParentAI ? (
<AssistantIcon
size="small"
assistant={alternativeAssistant || currentPersona}
/>
) : (
<div className="w-6" />
)}
<div className="w-full">
<div className="max-w-message-max break-words">
<div className="w-full ml-4">
<div className="max-w-message-max break-words">
{!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME ? (
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (
<>
{query !== undefined &&
handleShowRetrieved !== undefined &&
isCurrentlyShowingRetrieved !== undefined &&
!retrievalDisabled && (
<div className="mb-1">
<SearchSummary
docs={docs}
filteredDocs={filteredDocs}
query={query}
finished={toolCall?.tool_result != undefined}
hasDocs={hasDocs || false}
messageId={messageId}
handleShowRetrieved={handleShowRetrieved}
finished={
toolCall?.tool_result != undefined ||
isComplete!
}
toggleDocumentSelection={
toggleDocumentSelection
}
handleSearchQueryEdit={handleSearchQueryEdit}
/>
</div>
)}
{handleForceSearch &&
!hasChildAI &&
content &&
query === undefined &&
!hasDocs &&
@@ -318,7 +398,7 @@ export const AIMessage = ({
</div>
)}
</>
) : null}
)}
{toolCall &&
!TOOLS_WITH_CUSTOM_HANDLING.includes(
@@ -344,21 +424,53 @@ export const AIMessage = ({
{toolCall &&
toolCall.tool_name === INTERNET_SEARCH_TOOL_NAME && (
<ToolRunDisplay
toolName={
toolCall.tool_result
? `Searched the internet`
: `Searching the internet`
}
toolLogo={
<FiGlobe size={15} className="my-auto mr-1" />
}
isRunning={!toolCall.tool_result}
/>
<div className="my-2">
<ToolRunDisplay
toolName={
toolCall.tool_result
? `Searched the internet`
: `Searching the internet`
}
toolLogo={
<FiGlobe size={15} className="my-auto mr-1" />
}
isRunning={!toolCall.tool_result}
/>
</div>
)}
{graphs.map((graph, ind) => {
return graph.graph_type === GraphType.LINE_GRAPH ? (
<ModalChartWrapper
key={ind}
chartType="line"
fileId={graph.file_id}
>
<LineChartDisplay fileId={graph.file_id} />
</ModalChartWrapper>
) : (
<ModalChartWrapper
key={ind}
chartType="bar"
fileId={graph.file_id}
>
<BarChartDisplay fileId={graph.file_id} />
</ModalChartWrapper>
);
})}
{content || files ? (
<>
{toolCall?.tool_name == "create_graph" && (
<ModalChartWrapper
key={0}
chartType="line"
fileId={toolCall?.tool_result?.file_id}
>
<LineChartDisplay
fileId={toolCall?.tool_result?.file_id}
/>
</ModalChartWrapper>
)}
<FileDisplay files={files || []} />
{typeof content === "string" ? (
@@ -371,6 +483,50 @@ export const AIMessage = ({
const { node, ...rest } = props;
const value = rest.children;
if (
value
?.toString()
.startsWith(IMAGE_GENERATION_TOOL_NAME)
) {
const imageGenerationResult =
toolCall?.tool_result as ImageGenerationResults;
return (
<Popover
open={isPopoverOpen}
onOpenChange={
() => null
// setIsPopoverOpen(isPopoverOpen => !isPopoverOpen)
} // only allow closing from the icon
content={
<button
onMouseDown={() => {
setIsPopoverOpen(!isPopoverOpen);
}}
>
<ToolCallIcon className="cursor-pointer flex-none text-blue-500 hover:text-blue-700 !h-4 !w-4 inline-block" />
</button>
}
popover={
<DualPromptDisplay
arg="Prompt"
setPopup={setPopup!}
prompt1={
imageGenerationResult[0]
.revised_prompt
}
prompt2={
imageGenerationResult[1]
.revised_prompt
}
/>
}
side="top"
align="center"
/>
);
}
if (value?.toString().startsWith("*")) {
return (
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
@@ -378,10 +534,6 @@ export const AIMessage = ({
} else if (
value?.toString().startsWith("[")
) {
// for some reason <a> tags cause the onClick to not apply
// and the links are unclickable
// TODO: fix the fact that you have to double click to follow link
// for the first link
return (
<Citation link={rest?.href}>
{rest.children}
@@ -428,82 +580,22 @@ export const AIMessage = ({
) : isComplete ? null : (
<></>
)}
{isComplete && docs && docs.length > 0 && (
<div className="mt-2 -mx-8 w-full mb-4 flex relative">
<div className="w-full">
<div className="px-8 flex gap-x-2">
{!settings?.isMobile &&
filteredDocs.length > 0 &&
filteredDocs.slice(0, 2).map((doc, ind) => (
<div
key={doc.document_id}
className={`w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 pb-2 pt-1 border-b
`}
>
<a
href={doc.link || undefined}
target="_blank"
className="text-sm flex w-full pt-1 gap-x-1.5 overflow-hidden justify-between font-semibold text-text-700"
>
<Citation link={doc.link} index={ind + 1} />
<p className="shrink truncate ellipsis break-all">
{doc.semantic_identifier ||
doc.document_id}
</p>
<div className="ml-auto flex-none">
{doc.is_internet ? (
<InternetSearchIcon url={doc.link} />
) : (
<SourceIcon
sourceType={doc.source_type}
iconSize={18}
/>
)}
</div>
</a>
<div className="flex overscroll-x-scroll mt-.5">
<DocumentMetadataBlock document={doc} />
</div>
<div className="line-clamp-3 text-xs break-words pt-1">
{doc.blurb}
</div>
</div>
))}
<div
onClick={() => {
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={-1}
className="cursor-pointer w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 py-2 border-b"
>
<div className="text-sm flex justify-between font-semibold text-text-700">
<p className="line-clamp-1">See context</p>
<div className="flex gap-x-1">
{uniqueSources.map((sourceType, ind) => {
return (
<div key={ind} className="flex-none">
<SourceIcon
sourceType={sourceType}
iconSize={18}
/>
</div>
);
})}
</div>
</div>
<div className="line-clamp-3 text-xs break-words pt-1">
See more
</div>
</div>
</div>
</div>
</div>
)}
</div>
{/* <ModalChartWrapper chartType="line" fileId="fee2ff90-4ebe-43fc-858f-a95c73385da4" >
<LineChartDisplay fileId="fee2ff90-4ebe-43fc-858f-a95c73385da4" />
</ModalChartWrapper> */}
{/*
<ModalChartWrapper chartType="bar" fileId={"0ad36971-9353-42de-b89d-9c3361d3c3eb"}>
<BarChartDisplay fileId={"0ad36971-9353-42de-b89d-9c3361d3c3eb"} />
</ModalChartWrapper>
{handleFeedback &&
<ModalChartWrapper chartType="other" fileId={"066fc31f-56f0-48fb-98d3-ffd46f1ac0f5"}>
<ImageDisplay fileId={"066fc31f-56f0-48fb-98d3-ffd46f1ac0f5"} />
</ModalChartWrapper>
*/}
{!hasChildAI &&
handleFeedback &&
(isActive ? (
<div
className={`

View File

@@ -4,9 +4,21 @@ import {
} from "@/components/BasicClickable";
import { HoverPopup } from "@/components/HoverPopup";
import { Hoverable } from "@/components/Hoverable";
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
import { SourceIcon } from "@/components/SourceIcon";
import { ChevronDownIcon, InfoIcon } from "@/components/icons/icons";
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
import { Citation } from "@/components/search/results/Citation";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { Tooltip } from "@/components/tooltip/Tooltip";
import { useEffect, useRef, useState } from "react";
import {
DanswerDocument,
FilteredDanswerDocument,
} from "@/lib/search/interfaces";
import { ValidSources } from "@/lib/types";
import { useContext, useEffect, useRef, useState } from "react";
import { FiCheck, FiEdit2, FiSearch, FiX } from "react-icons/fi";
import { DownChevron } from "react-select/dist/declarations/src/components/indicators";
export function ShowHideDocsButton({
messageId,
@@ -37,24 +49,31 @@ export function ShowHideDocsButton({
export function SearchSummary({
query,
hasDocs,
filteredDocs,
finished,
messageId,
handleShowRetrieved,
docs,
toggleDocumentSelection,
handleSearchQueryEdit,
}: {
toggleDocumentSelection?: () => void;
docs?: DanswerDocument[] | null;
filteredDocs: FilteredDanswerDocument[];
finished: boolean;
query: string;
hasDocs: boolean;
messageId: number | null;
handleShowRetrieved: (messageId: number | null) => void;
handleSearchQueryEdit?: (query: string) => void;
}) {
const [isEditing, setIsEditing] = useState(false);
const [finalQuery, setFinalQuery] = useState(query);
const [isOverflowed, setIsOverflowed] = useState(false);
const searchingForRef = useRef<HTMLDivElement>(null);
const editQueryRef = useRef<HTMLInputElement>(null);
const [isDropdownOpen, setIsDropdownOpen] = useState(false);
const searchingForRef = useRef<HTMLDivElement | null>(null);
const editQueryRef = useRef<HTMLInputElement | null>(null);
const settings = useContext(SettingsContext);
const toggleDropdown = () => {
setIsDropdownOpen(!isDropdownOpen);
};
useEffect(() => {
const checkOverflow = () => {
@@ -68,7 +87,7 @@ export function SearchSummary({
};
checkOverflow();
window.addEventListener("resize", checkOverflow); // Recheck on window resize
window.addEventListener("resize", checkOverflow);
return () => window.removeEventListener("resize", checkOverflow);
}, []);
@@ -86,15 +105,30 @@ export function SearchSummary({
}, [query]);
const searchingForDisplay = (
<div className={`flex p-1 rounded ${isOverflowed && "cursor-default"}`}>
<FiSearch className="flex-none mr-2 my-auto" size={14} />
<div
className={`${!finished && "loading-text"}
!text-sm !line-clamp-1 !break-all px-0.5`}
ref={searchingForRef}
>
{finished ? "Searched" : "Searching"} for: <i> {finalQuery}</i>
</div>
<div
className={`flex my-auto items-center ${isOverflowed && "cursor-default"}`}
>
{finished ? (
<>
<div
onClick={() => {
toggleDropdown();
}}
className={`transition-colors duration-300 group-hover:text-text-toolhover cursor-pointer text-text-toolrun !line-clamp-1 !break-all pr-0.5`}
ref={searchingForRef}
>
Searched {filteredDocs.length > 0 && filteredDocs.length} document
{filteredDocs.length != 1 && "s"} for {query}
</div>
</>
) : (
<div
className={`loading-text !text-sm !line-clamp-1 !break-all px-0.5`}
ref={searchingForRef}
>
Searching for: <i> {finalQuery}</i>
</div>
)}
</div>
);
@@ -145,43 +179,126 @@ export function SearchSummary({
</div>
</div>
) : null;
const SearchBlock = ({ doc, ind }: { doc: DanswerDocument; ind: number }) => {
return (
<div
onClick={() => {
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={doc.document_id}
className={`flex items-start gap-3 px-4 py-3 text-token-text-secondary ${ind == 0 && "rounded-t-xl"} hover:bg-background-100 group relative text-sm`}
>
<div className="mt-1 scale-[.9] flex-none">
{doc.is_internet ? (
<InternetSearchIcon url={doc.link} />
) : (
<SourceIcon sourceType={doc.source_type} iconSize={18} />
)}
</div>
<div className="flex flex-col">
<a
href={doc.link}
target="_blank"
className="line-clamp-1 text-text-900"
>
<p className="shrink truncate ellipsis break-all ">
{doc.semantic_identifier || doc.document_id}
</p>
<p className="line-clamp-3 text-text-500 break-words">
{doc.blurb}
</p>
</a>
</div>
</div>
);
};
return (
<div className="flex">
{isEditing ? (
editInput
) : (
<>
<div className="text-sm">
{isOverflowed ? (
<HoverPopup
mainContent={searchingForDisplay}
popupContent={
<div>
<b>Full query:</b>{" "}
<div className="mt-1 italic">{query}</div>
</div>
}
direction="top"
<>
<div className="flex gap-x-2 group">
{isEditing ? (
editInput
) : (
<>
<div className="my-auto text-sm">
{isOverflowed ? (
<HoverPopup
mainContent={searchingForDisplay}
popupContent={
<div>
<b>Full query:</b>{" "}
<div className="mt-1 italic">{query}</div>
</div>
}
direction="top"
/>
) : (
searchingForDisplay
)}
</div>
<button
className="my-auto invisible group-hover:visible transition-all duration-300 rounded"
onClick={toggleDropdown}
>
<ChevronDownIcon
className={`transform transition-transform ${isDropdownOpen ? "rotate-180" : ""}`}
/>
</button>
{handleSearchQueryEdit ? (
<Tooltip delayDuration={1000} content={"Edit Search"}>
<button
className="my-auto invisible group-hover:visible transition-all duration-300 cursor-pointer rounded"
onClick={() => {
setIsEditing(true);
}}
>
<FiEdit2 />
</button>
</Tooltip>
) : (
searchingForDisplay
<></>
)}
</div>
{handleSearchQueryEdit && (
<Tooltip delayDuration={1000} content={"Edit Search"}>
<button
className="my-auto hover:bg-hover p-1.5 rounded"
</>
)}
</div>
{isDropdownOpen && docs && docs.length > 0 && (
<div
className={`mt-2 -mx-8 w-full mb-4 flex relative transition-all duration-300 ${isDropdownOpen ? "opacity-100 max-h-[1000px]" : "opacity-0 max-h-0"}`}
>
<div className="w-full">
<div className="mx-8 flex rounded max-h-[500px] overflow-y-scroll rounded-lg border-1.5 border divide-y divider-y-1.5 divider-y-border border-border flex-col gap-x-4">
{!settings?.isMobile &&
filteredDocs.length > 0 &&
filteredDocs.map((doc, ind) => (
<SearchBlock key={ind} doc={doc} ind={ind} />
))}
<div
onClick={() => {
setIsEditing(true);
if (toggleDocumentSelection) {
toggleDocumentSelection();
}
}}
key={-1}
className="cursor-pointer w-full flex transition-all duration-500 hover:bg-background-100 py-3 border-b"
>
<FiEdit2 />
</button>
</Tooltip>
)}
</>
<div key={0} className="px-3 invisible scale-[.9] flex-none">
<SourceIcon sourceType={"file"} iconSize={18} />
</div>
<div className="text-sm flex justify-between text-text-900">
<p className="line-clamp-1">See context</p>
<div className="flex gap-x-1"></div>
</div>
</div>
</div>
</div>
</div>
)}
</div>
</>
);
}

View File

@@ -101,7 +101,6 @@ export function SharedChatDisplay({
messageId={message.messageId}
content={message.message}
files={message.files || []}
personaName={chatSession.persona_name}
citedDocuments={getCitedDocumentsFromMessage(message)}
isComplete
/>

View File

@@ -0,0 +1,83 @@
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { CopyIcon } from "@/components/icons/icons";
import { Divider } from "@tremor/react";
import React, { forwardRef, useState } from "react";
import { FiCheck } from "react-icons/fi";
/** Props for DualPromptDisplay. */
interface PromptDisplayProps {
  // First prompt text; always rendered.
  prompt1: string;
  // Optional second prompt; rendered below a divider when present.
  prompt2?: string;
  // Label prefix for each section heading (e.g. "Prompt" -> "Prompt 1").
  arg: string;
  // Callback used to surface copy success/failure toasts.
  setPopup: (popupSpec: PopupSpec | null) => void;
}
/**
 * Card displaying one or two prompts, each with a copy-to-clipboard button.
 * Copy success/failure is reported through `setPopup`, and the last-copied
 * section briefly shows a "Copied!" state.
 */
const DualPromptDisplay = forwardRef<HTMLDivElement, PromptDisplayProps>(
  ({ prompt1, prompt2, setPopup, arg }, ref) => {
    // Index of the section that was just copied (0 or 1), or null when none.
    const [copied, setCopied] = useState<number | null>(null);

    const copyToClipboard = (text: string, index: number) => {
      navigator.clipboard
        .writeText(text)
        .then(() => {
          setPopup({ message: "Copied to clipboard", type: "success" });
          setCopied(index);
          setTimeout(() => setCopied(null), 2000); // Reset copy status after 2 seconds
        })
        .catch((err) => {
          // Keep the underlying cause visible for debugging instead of
          // silently discarding it.
          console.error("Failed to copy to clipboard:", err);
          setPopup({ message: "Failed to copy", type: "error" });
        });
    };

    // Defined inside the parent so it can close over `arg` and
    // `copyToClipboard`; it is stateless, so re-creation per render is
    // harmless.
    const PromptSection = ({
      copied,
      prompt,
      index,
    }: {
      copied: number | null;
      prompt: string;
      index: number;
    }) => (
      <div className="w-full p-2 rounded-lg">
        <h2 className="text-lg font-semibold mb-2">
          {arg} {index + 1}
        </h2>
        <p className="line-clamp-6 text-sm text-gray-800">{prompt}</p>
        <button
          onMouseDown={() => copyToClipboard(prompt, index)}
          className="flex mt-2 text-sm cursor-pointer items-center justify-center py-2 px-3 border border-background-200 bg-inverted text-text-900 rounded-full hover:bg-background-100 transition duration-200"
        >
          {copied === index ? (
            <>
              <FiCheck className="mr-2" size={16} />
              Copied!
            </>
          ) : (
            <>
              <CopyIcon className="mr-2" size={16} />
              Copy
            </>
          )}
        </button>
      </div>
    );

    return (
      // Attach the forwarded ref (the original accepted it via forwardRef but
      // never attached it, so parents could not measure/scroll this card).
      <div
        ref={ref}
        className="w-[400px] bg-inverted mx-auto p-6 rounded-lg shadow-lg"
      >
        <div className="flex flex-col gap-x-4">
          <PromptSection copied={copied} prompt={prompt1} index={0} />
          {prompt2 && (
            <>
              <Divider />
              <PromptSection copied={copied} prompt={prompt2} index={1} />
            </>
          )}
        </div>
      </div>
    );
  }
);
DualPromptDisplay.displayName = "DualPromptDisplay";
export default DualPromptDisplay;

View File

@@ -15,6 +15,7 @@ interface ModalProps {
titleSize?: string;
hideDividerForTitle?: boolean;
noPadding?: boolean;
hideCloseButton?: boolean;
}
export function Modal({
@@ -27,6 +28,7 @@ export function Modal({
hideDividerForTitle,
noPadding,
icon,
hideCloseButton,
}: ModalProps) {
const modalRef = useRef<HTMLDivElement>(null);
@@ -60,7 +62,7 @@ export function Modal({
${noPadding ? "" : "p-10"}
${className || ""}`}
>
{onOutsideClick && (
{onOutsideClick && !hideCloseButton && (
<div className="absolute top-2 right-2">
<button
onClick={onOutsideClick}

View File

@@ -145,6 +145,19 @@ export function ClientLayout({
),
link: "/admin/tools",
},
...(enableEnterprise
? [
{
name: (
<div className="flex">
<ClipboardIcon size={18} />
<div className="ml-1">Standard Answers</div>
</div>
),
link: "/admin/standard-answer",
},
]
: []),
{
name: (
<div className="flex">

View File

@@ -0,0 +1,152 @@
import React, { useState, useEffect } from "react";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import { FileDescriptor } from "@/app/chat/interfaces";
import { WarningCircle } from "@phosphor-icons/react";
// A parsed CSV row: header name -> cell value.
interface CSVData {
  [key: string]: string;
}

/** Props passed by FileWrapper to the file-body component it hosts. */
export interface ToolDisplay {
  // File to fetch and render (id + name).
  fileDescriptor: FileDescriptor;
  // True while the wrapper's brief loading delay is active.
  isLoading: boolean;
  // Drives the fade-in opacity transition once loading finishes.
  fadeIn: boolean;
}
/**
 * Fetches a CSV file by id and renders it as a scrollable table with a
 * sticky header row and sticky first column. Shows a skeleton while the
 * wrapper is loading and an error/empty state when the CSV cannot be shown.
 */
export const CsvContent = ({
  fileDescriptor,
  isLoading,
  fadeIn,
}: ToolDisplay) => {
  const [data, setData] = useState<CSVData[]>([]);
  const [headers, setHeaders] = useState<string[]>([]);

  useEffect(() => {
    fetchCSV(fileDescriptor.id);
  }, [fileDescriptor.id]);

  const fetchCSV = async (id: string) => {
    try {
      // Absolute path: a relative "api/..." URL resolves against the current
      // route and breaks on nested pages.
      const response = await fetch(`/api/chat/file/${id}`);
      if (!response.ok) {
        throw new Error("Failed to fetch CSV file");
      }
      // Guard against huge files before reading the whole body into memory.
      const contentLength = response.headers.get("Content-Length");
      const fileSizeInMB = contentLength
        ? parseInt(contentLength) / (1024 * 1024)
        : 0;
      const MAX_FILE_SIZE_MB = 5;
      if (fileSizeInMB > MAX_FILE_SIZE_MB) {
        throw new Error("File size exceeds the maximum limit of 5MB");
      }

      const csvData = await response.text();
      // Handle both LF and CRLF line endings.
      const rows = csvData.trim().split(/\r?\n/);
      if (rows.length === 0 || rows[0] === "") {
        setHeaders([]);
        setData([]);
        return;
      }
      // NOTE(review): naive comma split — does not handle quoted fields that
      // contain commas; fine for simple machine-generated CSVs.
      const parsedHeaders = rows[0].split(",");
      setHeaders(parsedHeaders);

      const parsedData: CSVData[] = rows.slice(1).map((row) => {
        const values = row.split(",");
        return parsedHeaders.reduce<CSVData>((obj, header, index) => {
          // Short rows yield empty cells rather than `undefined`.
          obj[header] = values[index] ?? "";
          return obj;
        }, {});
      });
      setData(parsedData);
    } catch (error) {
      console.error("Error fetching CSV file:", error);
      // Empty headers doubles as the "error" signal for the render below.
      setData([]);
      setHeaders([]);
    }
  };

  if (isLoading) {
    // Pulsing placeholder shown during the wrapper's loading delay.
    return (
      <div className="flex items-center justify-center h-[300px]">
        <div className="animate-pulse w- flex space-x-4">
          <div className="rounded-full bg-background-200 h-10 w-10"></div>
          <div className="w-full flex-1 space-y-4 py-1">
            <div className="h-2 w-full bg-background-200 rounded"></div>
            <div className="w-full space-y-3">
              <div className="grid grid-cols-3 gap-4">
                <div className="h-2 bg-background-200 rounded col-span-2"></div>
                <div className="h-2 bg-background-200 rounded col-span-1"></div>
              </div>
              <div className="h-2 bg-background-200 rounded"></div>
            </div>
          </div>
        </div>
      </div>
    );
  }

  return (
    <div
      className={`transition-opacity transform relative duration-1000 ease-in-out ${fadeIn ? "opacity-100" : "opacity-0"}`}
    >
      <div className={`overflow-y-auto flex relative max-h-[400px]`}>
        <Table className="!relative !overflow-y-scroll">
          <TableHeader className="z-20 !sticky !top-0">
            <TableRow className="!bg-neutral-100">
              {headers.map((header, index) => (
                <TableHead className=" " key={index}>
                  <p className="text-text-600 line-clamp-2 my-2 font-medium">
                    {index === 0 ? "" : header}
                  </p>
                </TableHead>
              ))}
            </TableRow>
          </TableHeader>
          <TableBody>
            {data.length > 0 ? (
              data.map((row, rowIndex) => (
                <TableRow key={rowIndex}>
                  {headers.map((header, cellIndex) => (
                    <TableCell
                      className={`${
                        cellIndex === 0 && "sticky left-0 !bg-neutral-100"
                      }`}
                      key={cellIndex}
                    >
                      {row[header]}
                    </TableCell>
                  ))}
                </TableRow>
              ))
            ) : (
              <TableRow>
                <TableCell
                  colSpan={headers.length}
                  className="text-center py-8"
                >
                  <div className="flex flex-col items-center justify-center space-y-2">
                    <WarningCircle className="w-8 h-8 text-error" />
                    <p className="text-text-600 font-medium">
                      {headers.length === 0
                        ? "Error loading CSV"
                        : "No data available"}
                    </p>
                    <p className="text-text-400 text-sm">
                      {headers.length === 0
                        ? "The CSV file may be too large or couldn't be loaded properly."
                        : "The CSV file appears to be empty."}
                    </p>
                  </div>
                </TableCell>
              </TableRow>
            )}
          </TableBody>
        </Table>
      </div>
    </div>
  );
};

View File

@@ -0,0 +1,143 @@
import React, { useState, useEffect } from "react";
import {
CustomTooltip,
TooltipGroup,
} from "@/components/tooltip/CustomTooltip";
import {
DexpandTwoIcon,
DownloadCSVIcon,
ExpandTwoIcon,
OpenIcon,
} from "@/components/icons/icons";
import { Card, CardHeader, CardTitle, CardContent } from "@/components/ui/card";
import { Modal } from "@/components/Modal";
import { FileDescriptor } from "@/app/chat/interfaces";
import { CsvContent, ToolDisplay } from "./CSVContent";
/**
 * Renders a CSV tool result as an inline card, with an optional full-screen
 * modal view of the same file layered on top when expanded.
 */
export default function ToolResult({
  csvFileDescriptor,
  close,
}: {
  csvFileDescriptor: FileDescriptor;
  close: () => void;
}) {
  const [isExpanded, setIsExpanded] = useState(false);

  // Toggles between the inline card and the modal view.
  const toggleExpanded = () => setIsExpanded((current) => !current);

  // Props common to both the inline and the modal rendering.
  const sharedProps = {
    fileDescriptor: csvFileDescriptor,
    close: close,
    expand: toggleExpanded,
    ContentComponent: CsvContent,
  };

  return (
    <>
      {isExpanded && (
        <Modal
          hideCloseButton
          onOutsideClick={() => setIsExpanded(false)}
          className="!max-w-5xl overflow-hidden rounded-lg animate-all ease-in !p-0"
        >
          <FileWrapper {...sharedProps} expanded={true} />
        </Modal>
      )}
      {/* The inline card stays mounted even while the modal is open. */}
      <FileWrapper {...sharedProps} expanded={false} />
    </>
  );
}
/** Props shared by the inline and modal renderings of a tool-result file. */
interface FileWrapperProps {
  // File to display (id + name).
  fileDescriptor: FileDescriptor;
  // Hides the result entirely.
  close: () => void;
  // Whether this instance is the full-screen (modal) rendering.
  expanded: boolean;
  // Toggles between inline and full-screen.
  expand: () => void;
  // Component that renders the file body (e.g. CsvContent).
  ContentComponent: React.ComponentType<ToolDisplay>;
}
/**
 * Card chrome around a tool-result file: a title bar with download /
 * expand / hide actions, plus a scrollable content area rendered by
 * `ContentComponent`. Drives a brief skeleton + fade-in sequence on mount.
 */
export const FileWrapper = ({
  fileDescriptor,
  close,
  expanded,
  expand,
  ContentComponent,
}: FileWrapperProps) => {
  const [isLoading, setIsLoading] = useState(true);
  const [fadeIn, setFadeIn] = useState(false);

  useEffect(() => {
    // Simulated loading delay so the skeleton doesn't flash on mount.
    // Cleared on unmount to avoid a setState on an unmounted component.
    const timer = setTimeout(() => setIsLoading(false), 300);
    return () => clearTimeout(timer);
  }, []);

  useEffect(() => {
    if (!isLoading) {
      // Small delay lets the content mount before the opacity transition.
      const timer = setTimeout(() => setFadeIn(true), 50);
      return () => clearTimeout(timer);
    }
    setFadeIn(false);
  }, [isLoading]);

  const downloadFile = () => {
    // Implement download logic here
  };

  return (
    <div
      className={`${
        !expanded ? "w-message-sm" : "w-full"
      } !rounded !rounded-lg overflow-y-hidden w-full border border-border`}
    >
      <CardHeader className="w-full !py-0 !pb-4 border-b border-border border-b-neutral-200 !pt-4 !mb-0 z-[10] top-0">
        <div className="flex justify-between items-center">
          <CardTitle className="!my-auto text-ellipsis line-clamp-1 text-xl font-semibold text-text-700 pr-4 transition-colors duration-300">
            {fileDescriptor.name}
          </CardTitle>
          <div className="flex !my-auto">
            <TooltipGroup gap="gap-x-4">
              <CustomTooltip showTick line content="Download file">
                <button onClick={downloadFile}>
                  <DownloadCSVIcon className="cursor-pointer transition-colors duration-300 hover:text-text-800 h-6 w-6 text-text-400" />
                </button>
              </CustomTooltip>
              <CustomTooltip
                line
                showTick
                content={expanded ? "Minimize" : "Full screen"}
              >
                <button onClick={expand}>
                  {!expanded ? (
                    <ExpandTwoIcon className="transition-colors duration-300 hover:text-text-800 h-6 w-6 cursor-pointer text-text-400" />
                  ) : (
                    <DexpandTwoIcon className="transition-colors duration-300 hover:text-text-800 h-6 w-6 cursor-pointer text-text-400" />
                  )}
                </button>
              </CustomTooltip>
              <CustomTooltip showTick line content="Hide">
                <button onClick={close}>
                  <OpenIcon className="transition-colors duration-300 hover:text-text-800 h-6 w-6 cursor-pointer text-text-400" />
                </button>
              </CustomTooltip>
            </TooltipGroup>
          </div>
        </div>
      </CardHeader>
      <Card className="!rounded-none w-full max-h-[600px] !p-0 relative overflow-x-scroll overflow-y-scroll mx-auto">
        <CardContent className="!p-0">
          <ContentComponent
            fileDescriptor={fileDescriptor}
            isLoading={isLoading}
            fadeIn={fadeIn}
          />
        </CardContent>
      </Card>
    </div>
  );
};

View File

@@ -0,0 +1,89 @@
import React, { useState, useEffect } from "react";
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip } from "recharts";
import { CardContent } from "@/components/ui/card";
/** One bar in the plot payload fetched from the file endpoint. */
interface BarDataPoint {
  x: number;
  y: number;
  width: number;
  color: string;
}

/** Bar-plot JSON payload (matplotlib-style export). */
interface BarPlotData {
  data: BarDataPoint[];
  title: string;
  xlabel: string;
  ylabel: string;
  xticks: number[];
  // Category labels aligned with `data` by index.
  xticklabels: string[];
}
/**
 * Fetches a bar-plot JSON payload by file id and renders it with Recharts.
 * Shows a plain "Loading..." placeholder until the payload arrives.
 */
export function BarChartDisplay({ fileId }: { fileId: string }) {
  const [barPlotData, setBarPlotData] = useState<BarPlotData | null>(null);

  useEffect(() => {
    fetchPlotData(fileId);
  }, [fileId]);

  const fetchPlotData = async (id: string) => {
    try {
      // Absolute path: a relative "api/..." URL resolves against the current
      // route and breaks on nested pages.
      const response = await fetch(`/api/chat/file/${id}`);
      if (!response.ok) {
        throw new Error("Failed to fetch plot data");
      }
      const data: BarPlotData = await response.json();
      setBarPlotData(data);
    } catch (error) {
      console.error("Error fetching plot data:", error);
    }
  };

  if (!barPlotData) {
    return <div>Loading...</div>;
  }

  // Recharts expects [{ name, value }] rows rather than parallel arrays.
  const transformedData = barPlotData.data.map((point, index) => ({
    name: barPlotData.xticklabels[index] || point.x.toString(),
    value: point.y,
  }));

  return (
    <>
      <h2>{barPlotData.title}</h2>
      <BarChart
        width={600}
        height={400}
        data={transformedData}
        margin={{
          top: 20,
          right: 30,
          left: 20,
          bottom: 5,
        }}
      >
        <CartesianGrid strokeDasharray="3 3" />
        <XAxis
          dataKey="name"
          label={{
            value: barPlotData.xlabel,
            position: "insideBottom",
            offset: -10,
          }}
        />
        <YAxis
          label={{
            value: barPlotData.ylabel,
            angle: -90,
            position: "insideLeft",
          }}
        />
        <Tooltip />
        {/* Fall back to a neutral color instead of crashing on empty data. */}
        <Bar dataKey="value" fill={barPlotData.data[0]?.color ?? "#8884d8"} />
      </BarChart>
    </>
  );
}
export default BarChartDisplay;

View File

@@ -0,0 +1,56 @@
import { buildImgUrl } from "@/app/chat/files/images/utils";
import React, { useState, useEffect } from "react";
/**
 * Renders an image file inline; clicking it opens a full-screen lightbox
 * overlay, dismissed by clicking anywhere on the backdrop.
 *
 * (Removed: a large block of commented-out fetch code and an `imageUrl`
 * state that was never written to — `buildImgUrl` derives the URL directly.)
 */
export function ImageDisplay({ fileId }: { fileId: string }) {
  // Whether the full-size lightbox overlay is showing.
  const [fullImageShowing, setFullImageShowing] = useState(false);

  return (
    <>
      <img
        className="w-full mx-auto object-cover object-center overflow-hidden rounded-lg w-full h-full transition-opacity duration-300 opacity-100"
        onClick={() => setFullImageShowing(true)}
        src={buildImgUrl(fileId)}
        alt="Fetched image"
        loading="lazy"
      />
      {fullImageShowing && (
        <div
          className="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-50"
          onClick={() => setFullImageShowing(false)}
        >
          <img
            src={buildImgUrl(fileId)}
            alt="Full size image"
            className="max-w-90vw max-h-90vh object-contain"
          />
        </div>
      )}
    </>
  );
}

View File

@@ -0,0 +1,261 @@
import React, { useEffect, useState } from "react";
import { PickaxeIcon, TrendingUp } from "lucide-react";
import { CartesianGrid, Line, LineChart, XAxis, YAxis } from "recharts";
import {
ChartConfig,
ChartContainer,
ChartTooltip,
ChartTooltipContent,
} from "@/components/ui/chart";
import { CardFooter } from "@/components/ui/card";
import {
DexpandTwoIcon,
DownloadCSVIcon,
ExpandTwoIcon,
OpenIcon,
PaintingIcon,
PaintingIconSkeleton,
} from "@/components/icons/icons";
import {
CustomTooltip,
TooltipGroup,
} from "@/components/tooltip/CustomTooltip";
import { Modal } from "@/components/Modal";
import { ChartDataPoint, ChartType, Data, PlotData } from "./types";
import { SelectionBackground } from "@phosphor-icons/react";
/**
 * Hosts a chart inside ChartWrapper and, when expanded, lifts that same
 * wrapper into a modal. Exactly one ChartWrapper instance is rendered at
 * a time.
 */
export function ModalChartWrapper({
  children,
  fileId,
  chartType,
}: {
  children: JSX.Element;
  fileId: string;
  chartType: ChartType;
}) {
  const [expanded, setExpanded] = useState(false);

  const toggleExpanded = () => setExpanded((current) => !current);

  // Build the wrapped chart once; both branches below render this element.
  const wrappedChart = (
    <ChartWrapper
      chartType={chartType}
      expand={toggleExpanded}
      expanded={expanded}
      fileId={fileId}
    >
      {children}
    </ChartWrapper>
  );

  if (expanded) {
    return (
      <Modal
        onOutsideClick={() => setExpanded(false)}
        className="animate-all ease-in !p-0"
      >
        {wrappedChart}
      </Modal>
    );
  }

  return wrappedChart;
}
/**
 * Card chrome around a chart: fetches the chart's JSON payload (for a
 * title and metadata), and renders a header with static-view / download /
 * expand actions above the chart itself. For `chartType === "other"`
 * (static images) no fetch is performed and a placeholder title is used.
 */
export function ChartWrapper({
  expanded,
  children,
  fileId,
  chartType,
  expand,
}: {
  expanded: boolean;
  children: JSX.Element;
  chartType: ChartType;
  fileId: string;
  expand: () => void;
}) {
  const [plotDataJson, setPlotDataJson] = useState<Data | null>(null);

  useEffect(() => {
    fetchPlotData(fileId);
  }, [fileId]);

  const fetchPlotData = async (id: string) => {
    if (chartType === "other") {
      // Static images carry no JSON payload; just give the card a title.
      setPlotDataJson({ title: "Uploaded Chart" });
    } else {
      try {
        // Absolute path: a relative "api/..." URL resolves against the
        // current route and breaks on nested pages.
        const response = await fetch(`/api/chat/file/${id}`);
        if (!response.ok) {
          throw new Error("Failed to fetch plot data");
        }
        const data = await response.json();
        setPlotDataJson(data);
      } catch (error) {
        console.error("Error fetching plot data:", error);
      }
    }
  };

  const downloadFile = () => {
    if (!fileId) return;
    // Implement download functionality here
  };

  if (!plotDataJson) {
    return <div>Loading...</div>;
  }

  return (
    <div className="bg-background-50 group rounded-lg shadow-md relative">
      <div className="relative p-4">
        <div className="relative flex pb-2 items-center justify-between">
          <h2 className="text-xl font-semibold mb-2">{plotDataJson.title}</h2>
          <div className="flex gap-x-2">
            <TooltipGroup>
              {chartType !== "other" && (
                <CustomTooltip
                  showTick
                  line
                  position="top"
                  content="View Static file"
                >
                  <button onClick={() => downloadFile()}>
                    <PaintingIconSkeleton className="cursor-pointer transition-colors duration-300 hover:text-neutral-800 h-6 w-6 text-neutral-400" />
                  </button>
                </CustomTooltip>
              )}
              <CustomTooltip
                showTick
                line
                position="top"
                content="Download file"
              >
                <button onClick={() => downloadFile()}>
                  <DownloadCSVIcon className="cursor-pointer ml-4 transition-colors duration-300 hover:text-neutral-800 h-6 w-6 text-neutral-400" />
                </button>
              </CustomTooltip>
              <CustomTooltip
                line
                position="top"
                content={expanded ? "Minimize" : "Full screen"}
              >
                <button onClick={() => expand()}>
                  {!expanded ? (
                    <ExpandTwoIcon className="transition-colors duration-300 ml-4 hover:text-neutral-800 h-6 w-6 cursor-pointer text-neutral-400" />
                  ) : (
                    <DexpandTwoIcon className="transition-colors duration-300 ml-4 hover:text-neutral-800 h-6 w-6 cursor-pointer text-neutral-400" />
                  )}
                </button>
              </CustomTooltip>
            </TooltipGroup>
          </div>
        </div>
        {children}
        {chartType === "other" && (
          <div className="absolute bottom-6 right-6 p-1.5 rounded flex gap-x-2 items-center border border-neutral-200 bg-neutral-50 opacity-0 transition-opacity duration-300 ease-in-out text-sm text-gray-500 group-hover:opacity-100">
            <SelectionBackground />
            Interactive charts of this type are not supported yet
          </div>
        )}
      </div>
      <CardFooter className="flex-col items-start gap-2 text-sm">
        <div className="flex gap-2 font-medium leading-none">
          Data from Matplotlib plot <TrendingUp className="h-4 w-4" />
        </div>
      </CardFooter>
    </div>
  );
}
/**
 * Fetches a line-plot JSON payload by file id and renders its first series
 * with Recharts. Renders an empty chart until the data arrives.
 */
export function LineChartDisplay({ fileId }: { fileId: string }) {
  const [chartData, setChartData] = useState<ChartDataPoint[]>([]);
  const [chartConfig, setChartConfig] = useState<ChartConfig>({});

  useEffect(() => {
    fetchPlotData(fileId);
  }, [fileId]);

  const fetchPlotData = async (id: string) => {
    try {
      // Absolute path: a relative "api/..." URL resolves against the current
      // route and breaks on nested pages.
      const response = await fetch(`/api/chat/file/${id}`);
      if (!response.ok) {
        throw new Error("Failed to fetch plot data");
      }
      const plotDataJson: PlotData = await response.json();

      // Only the first series is plotted; bail out (empty chart) rather
      // than crashing when the payload carries no series.
      const series = plotDataJson.data[0];
      if (!series) {
        return;
      }

      // Convert parallel x/y arrays into Recharts' row format.
      const transformedData: ChartDataPoint[] = series.x.map((x, index) => ({
        x: x,
        y: series.y[index],
      }));
      setChartData(transformedData);

      setChartConfig({
        y: {
          label: series.label,
          color: series.color,
        },
      });
    } catch (error) {
      console.error("Error fetching plot data:", error);
    }
  };

  return (
    <div className="w-full h-full">
      <ChartContainer config={chartConfig}>
        <LineChart
          accessibilityLayer
          data={chartData}
          margin={{
            left: 12,
            right: 12,
          }}
        >
          <CartesianGrid vertical={false} />
          <XAxis
            dataKey="x"
            tickLine={false}
            axisLine={false}
            tickMargin={8}
            tickFormatter={(value: number) => value.toFixed(2)}
          />
          <YAxis
            dataKey="y"
            tickLine={false}
            axisLine={false}
            tickMargin={8}
            tickFormatter={(value: number) => value.toFixed(2)}
          />
          <ChartTooltip
            cursor={false}
            content={<ChartTooltipContent hideLabel />}
          />
          <Line
            dataKey="y"
            type="natural"
            stroke={chartConfig.y?.color || "var(--chart-1)"}
            strokeWidth={2}
            dot={false}
          />
        </LineChart>
      </ChartContainer>
    </div>
  );
}

View File

@@ -0,0 +1,93 @@
import React, { useEffect, useState } from "react";
import { TrendingUp } from "lucide-react";
import {
PolarAngleAxis,
PolarGrid,
PolarRadiusAxis,
Radar,
RadarChart,
ResponsiveContainer,
} from "recharts";
import { CardContent } from "@/components/ui/card";
import {
ChartConfig,
ChartContainer,
ChartTooltip,
ChartTooltipContent,
} from "@/components/ui/chart";
import { PolarChartDataPoint, PolarPlotData } from "./types";
/**
 * Fetches a polar-plot JSON payload by file id and renders its first series
 * as a Recharts radar chart. Theta values are converted from radians (as
 * exported) to degrees for the angle axis.
 */
export function PolarChartDisplay({ fileId }: { fileId: string }) {
  const [chartData, setChartData] = useState<PolarChartDataPoint[]>([]);
  const [chartConfig, setChartConfig] = useState<ChartConfig>({});
  const [plotDataJson, setPlotDataJson] = useState<PolarPlotData | null>(null);

  useEffect(() => {
    fetchPlotData(fileId);
  }, [fileId]);

  const fetchPlotData = async (id: string) => {
    try {
      // Absolute path: a relative "api/..." URL resolves against the current
      // route and breaks on nested pages.
      const response = await fetch(`/api/chat/file/${id}`);
      if (!response.ok) {
        throw new Error("Failed to fetch plot data");
      }
      const data: PolarPlotData = await response.json();
      setPlotDataJson(data);

      // Only the first series is plotted; bail out (stays on "Loading...")
      // rather than crashing when the payload carries no series.
      const series = data.data[0];
      if (!series) {
        return;
      }

      // Transform the JSON data to the format expected by the chart.
      const transformedData: PolarChartDataPoint[] = series.theta.map(
        (angle, index) => ({
          angle: (angle * 180) / Math.PI, // Convert radians to degrees
          radius: series.r[index],
        })
      );
      setChartData(transformedData);

      // Create the chart config from the JSON data.
      setChartConfig({
        y: {
          label: series.label,
          color: series.color,
        },
      });
    } catch (error) {
      console.error("Error fetching plot data:", error);
    }
  };

  if (!plotDataJson) {
    return <div>Loading...</div>;
  }

  return (
    <CardContent>
      <ChartContainer config={chartConfig}>
        <ResponsiveContainer width="100%" height={400}>
          <RadarChart cx="50%" cy="50%" outerRadius="80%" data={chartData}>
            <PolarGrid />
            <PolarAngleAxis dataKey="angle" type="number" domain={[0, 360]} />
            <PolarRadiusAxis angle={30} domain={[0, plotDataJson.rmax]} />
            <ChartTooltip
              cursor={false}
              content={<ChartTooltipContent hideLabel />}
            />
            <Radar
              name="Polar Plot"
              dataKey="radius"
              stroke={chartConfig.y?.color || "var(--chart-1)"}
              fill={chartConfig.y?.color || "var(--chart-1)"}
              fillOpacity={0.6}
            />
          </RadarChart>
        </ResponsiveContainer>
      </ChartContainer>
    </CardContent>
  );
}
export default PolarChartDisplay;

View File

@@ -0,0 +1,37 @@
// Shared type definitions for the chart display components.

/** Minimal payload every chart file provides: a display title. */
export interface Data {
  title: string;
}

/** Cartesian line-plot payload: series of parallel x/y arrays. */
export interface PlotData extends Data {
  data: Array<{
    x: number[];
    y: number[];
    label: string;
    color: string;
  }>;
  xlabel: string;
  ylabel: string;
}

/** A single (x, y) point after transforming PlotData for Recharts. */
export interface ChartDataPoint {
  x: number;
  y: number;
}

/** Polar-plot payload; theta is in radians (consumers convert to degrees). */
export interface PolarPlotData extends Data {
  data: Array<{
    theta: number[];
    r: number[];
    label: string;
    color: string;
  }>;
  rmax: number;
  rticks: number[];
  rlabel_position: number;
}

/** A single polar point after converting to degrees for Recharts. */
export interface PolarChartDataPoint {
  angle: number;
  radius: number;
}

/** Chart renderers supported by the display components. */
export type ChartType = "line" | "bar" | "radial" | "other";

View File

@@ -122,6 +122,102 @@ export const AssistantsIconSkeleton = ({
);
};
/** "Open book" icon (14x14 viewBox), sized via inline style. */
export const OpenIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => {
  return (
    <svg
      style={{ width: `${size}px`, height: `${size}px` }}
      className={`w-[${size}px] h-[${size}px] ` + className}
      xmlns="http://www.w3.org/2000/svg"
      width="200"
      height="200"
      viewBox="0 0 14 14"
    >
      {/* camelCased SVG props: kebab-case triggers React's invalid-DOM-property
          warning and is inconsistent with the other icons in this file. */}
      <path
        fill="none"
        stroke="currentColor"
        strokeLinecap="round"
        strokeLinejoin="round"
        d="M7 13.5a9.26 9.26 0 0 0-5.61-2.95a1 1 0 0 1-.89-1V1.5A1 1 0 0 1 1.64.51A9.3 9.3 0 0 1 7 3.43zm0 0a9.26 9.26 0 0 1 5.61-2.95a1 1 0 0 0 .89-1V1.5a1 1 0 0 0-1.14-.99A9.3 9.3 0 0 0 7 3.43z"
      />
    </svg>
  );
};
/** "Collapse" (de-expand) arrows icon (14x14 viewBox), sized via inline style. */
export const DexpandTwoIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => {
  return (
    <svg
      style={{ width: `${size}px`, height: `${size}px` }}
      className={`w-[${size}px] h-[${size}px] ` + className}
      xmlns="http://www.w3.org/2000/svg"
      width="200"
      height="200"
      viewBox="0 0 14 14"
    >
      {/* camelCased SVG props: kebab-case triggers React's invalid-DOM-property
          warning and is inconsistent with the other icons in this file. */}
      <path
        fill="none"
        stroke="currentColor"
        strokeLinecap="round"
        strokeLinejoin="round"
        d="m.5 13.5l5-5m-4 0h4v4m8-12l-5 5m4 0h-4v-4"
      />
    </svg>
  );
};
/** "Expand" arrows icon (14x14 viewBox), sized via inline style. */
export const ExpandTwoIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => {
  return (
    <svg
      style={{ width: `${size}px`, height: `${size}px` }}
      className={`w-[${size}px] h-[${size}px] ` + className}
      xmlns="http://www.w3.org/2000/svg"
      width="200"
      height="200"
      viewBox="0 0 14 14"
    >
      {/* camelCased SVG props: kebab-case triggers React's invalid-DOM-property
          warning and is inconsistent with the other icons in this file. */}
      <path
        fill="none"
        stroke="currentColor"
        strokeLinecap="round"
        strokeLinejoin="round"
        d="m8.5 5.5l5-5m-4 0h4v4m-8 4l-5 5m4 0h-4v-4"
      />
    </svg>
  );
};
/** Download (arrow-into-tray) icon (14x14 viewBox), sized via inline style. */
export const DownloadCSVIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => {
  return (
    <svg
      style={{ width: `${size}px`, height: `${size}px` }}
      className={`w-[${size}px] h-[${size}px] ` + className}
      xmlns="http://www.w3.org/2000/svg"
      width="200"
      height="200"
      viewBox="0 0 14 14"
    >
      {/* camelCased SVG props: kebab-case triggers React's invalid-DOM-property
          warning and is inconsistent with the other icons in this file. */}
      <path
        fill="none"
        stroke="currentColor"
        strokeLinecap="round"
        strokeLinejoin="round"
        d="M.5 10.5v1a2 2 0 0 0 2 2h9a2 2 0 0 0 2-2v-1M4 6l3 3.5L10 6M7 9.5v-9"
      />
    </svg>
  );
};
export const LightBulbIcon = ({
size,
className = defaultTailwindCSS,
@@ -2811,3 +2907,40 @@ export const WindowsIcon = ({
</svg>
);
};
/** Custom SVG icon for tool calls (19x15 viewBox), sized via inline style. */
export const ToolCallIcon = ({
  size = 16,
  className = defaultTailwindCSS,
}: IconProps) => {
  return (
    <svg
      style={{ width: `${size}px`, height: `${size}px` }}
      className={`w-[${size}px] h-[${size}px] ` + className}
      xmlns="http://www.w3.org/2000/svg"
      viewBox="0 0 19 15"
      fill="none"
    >
      <path
        d="M4.42 0.75H2.8625H2.75C1.64543 0.75 0.75 1.64543 0.75 2.75V11.65C0.75 12.7546 1.64543 13.65 2.75 13.65H2.8625C2.8625 13.65 2.8625 13.65 2.8625 13.65C2.8625 13.65 4.00751 13.65 4.42 13.65M13.98 13.65H15.5375H15.65C16.7546 13.65 17.65 12.7546 17.65 11.65V2.75C17.65 1.64543 16.7546 0.75 15.65 0.75H15.5375H13.98"
        stroke="currentColor"
        strokeWidth="1.5"
        strokeLinecap="round"
        strokeLinejoin="round"
      />
      <path
        d="M5.55283 4.21963C5.25993 3.92674 4.78506 3.92674 4.49217 4.21963C4.19927 4.51252 4.19927 4.9874 4.49217 5.28029L6.36184 7.14996L4.49217 9.01963C4.19927 9.31252 4.19927 9.7874 4.49217 10.0803C4.78506 10.3732 5.25993 10.3732 5.55283 10.0803L7.95283 7.68029C8.24572 7.3874 8.24572 6.91252 7.95283 6.61963L5.55283 4.21963Z"
        fill="currentColor"
        stroke="currentColor"
        strokeWidth="0.2"
        strokeLinecap="round"
        strokeLinejoin="round"
      />
      <path
        d="M9.77753 8.75003C9.3357 8.75003 8.97753 9.10821 8.97753 9.55003C8.97753 9.99186 9.3357 10.35 9.77753 10.35H13.2775C13.7194 10.35 14.0775 9.99186 14.0775 9.55003C14.0775 9.10821 13.7194 8.75003 13.2775 8.75003H9.77753Z"
        fill="currentColor"
        stroke="currentColor"
        strokeWidth="0.1"
      />
    </svg>
  );
};

View File

@@ -3,7 +3,6 @@
import {
DanswerDocument,
DocumentRelevance,
Relevance,
SearchDanswerDocument,
} from "@/lib/search/interfaces";
import { DocumentFeedbackBlock } from "./DocumentFeedbackBlock";
@@ -12,11 +11,10 @@ import { PopupSpec } from "../admin/connectors/Popup";
import { DocumentUpdatedAtBadge } from "./DocumentUpdatedAtBadge";
import { SourceIcon } from "../SourceIcon";
import { MetadataBadge } from "../MetadataBadge";
import { BookIcon, CheckmarkIcon, LightBulbIcon, XIcon } from "../icons/icons";
import { BookIcon, LightBulbIcon } from "../icons/icons";
import { FaStar } from "react-icons/fa";
import { FiTag } from "react-icons/fi";
import { DISABLE_LLM_DOC_RELEVANCE } from "@/lib/constants";
import { SettingsContext } from "../settings/SettingsProvider";
import { CustomTooltip, TooltipGroup } from "../tooltip/CustomTooltip";
import { WarningCircle } from "@phosphor-icons/react";

View File

@@ -19,9 +19,10 @@ const TooltipGroupContext = createContext<{
hoverCountRef: { current: false },
});
export const TooltipGroup: React.FC<{ children: React.ReactNode }> = ({
children,
}) => {
export const TooltipGroup: React.FC<{
children: React.ReactNode;
gap?: string;
}> = ({ children, gap = "" }) => {
const [groupHovered, setGroupHovered] = useState(false);
const hoverCountRef = useRef(false);
@@ -29,7 +30,7 @@ export const TooltipGroup: React.FC<{ children: React.ReactNode }> = ({
<TooltipGroupContext.Provider
value={{ groupHovered, setGroupHovered, hoverCountRef }}
>
<div className="inline-flex">{children}</div>
<div className={`inline-flex ${gap}`}>{children}</div>
</TooltipGroupContext.Provider>
);
};

View File

@@ -0,0 +1,56 @@
import * as React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cva, type VariantProps } from "class-variance-authority";
import { cn } from "@/lib/utils";
// class-variance-authority recipe for <Button />: base classes applied to
// every button plus two variant axes (`variant` for visual treatment,
// `size` for dimensions), with defaults for both.
const buttonVariants = cva(
  "inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50",
  {
    variants: {
      // Visual style of the button.
      variant: {
        default: "bg-primary text-primary-foreground hover:bg-primary/90",
        destructive:
          "bg-destructive text-destructive-foreground hover:bg-destructive/90",
        outline:
          "border border-input bg-background hover:bg-accent hover:text-accent-foreground",
        secondary:
          "bg-secondary text-secondary-foreground hover:bg-secondary/80",
        ghost: "hover:bg-accent hover:text-accent-foreground",
        link: "text-primary underline-offset-4 hover:underline",
      },
      // Height/padding presets; `icon` yields a square button.
      size: {
        default: "h-10 px-4 py-2",
        sm: "h-9 rounded-md px-3",
        lg: "h-11 rounded-md px-8",
        icon: "h-10 w-10",
      },
    },
    defaultVariants: {
      variant: "default",
      size: "default",
    },
  }
);
/**
 * Props for <Button />: all native <button> attributes plus the cva
 * variant props (`variant`, `size`). When `asChild` is true, the button
 * renders its child element via Radix Slot instead of a native <button>.
 */
export interface ButtonProps
  extends React.ButtonHTMLAttributes<HTMLButtonElement>,
    VariantProps<typeof buttonVariants> {
  asChild?: boolean;
}
/**
 * Themeable button. With `asChild`, forwards all props to the child
 * element through Radix `Slot` instead of rendering a native <button>.
 */
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
  ({ className, variant, size, asChild = false, ...rest }, ref) => {
    const Component = asChild ? Slot : "button";
    const classes = cn(buttonVariants({ variant, size, className }));
    return <Component ref={ref} className={classes} {...rest} />;
  }
);
Button.displayName = "Button";
export { Button, buttonVariants };

View File

@@ -0,0 +1,86 @@
import * as React from "react";
import { cn } from "@/lib/utils";
const Card = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn(
"rounded-lg border bg-card text-card-foreground shadow-sm",
className
)}
{...props}
/>
));
Card.displayName = "Card";
const CardHeader = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn("flex flex-col space-y-1.5 py-6 px-4", className)}
{...props}
/>
));
CardHeader.displayName = "CardHeader";
/**
 * Card heading, rendered as an <h3>.
 * Fix: the forwarded ref was typed as HTMLParagraphElement although the
 * rendered DOM node is a heading — a ref typed that way never matched the
 * actual element. Corrected to HTMLHeadingElement.
 */
const CardTitle = React.forwardRef<
  HTMLHeadingElement,
  React.HTMLAttributes<HTMLHeadingElement>
>(({ className, ...props }, ref) => (
  <h3
    ref={ref}
    className={cn(
      "text-2xl font-semibold leading-none tracking-tight",
      className
    )}
    {...props}
  />
));
CardTitle.displayName = "CardTitle";
const CardDescription = React.forwardRef<
HTMLParagraphElement,
React.HTMLAttributes<HTMLParagraphElement>
>(({ className, ...props }, ref) => (
<p
ref={ref}
className={cn("text-sm text-muted-foreground", className)}
{...props}
/>
));
CardDescription.displayName = "CardDescription";
const CardContent = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div ref={ref} className={cn("p-6 pt-0", className)} {...props} />
));
CardContent.displayName = "CardContent";
const CardFooter = React.forwardRef<
HTMLDivElement,
React.HTMLAttributes<HTMLDivElement>
>(({ className, ...props }, ref) => (
<div
ref={ref}
className={cn("flex items-center p-6 pt-0", className)}
{...props}
/>
));
CardFooter.displayName = "CardFooter";
export {
Card,
CardHeader,
CardFooter,
CardTitle,
CardDescription,
CardContent,
};

View File

@@ -0,0 +1,365 @@
"use client";
import * as React from "react";
import * as RechartsPrimitive from "recharts";
import { cn } from "@/lib/utils";
// Format: { THEME_NAME: CSS_SELECTOR }
const THEMES = { light: "", dark: ".dark" } as const;

// Per-series chart configuration: maps a data key to an optional display
// label/icon plus either a single `color` or one color per theme in THEMES
// (the two color options are mutually exclusive via `never`).
export type ChartConfig = {
  [k in string]: {
    label?: React.ReactNode;
    icon?: React.ComponentType;
  } & (
    | { color?: string; theme?: never }
    | { color?: never; theme: Record<keyof typeof THEMES, string> }
  );
};
// Context value shared by chart subcomponents (tooltip/legend content).
type ChartContextProps = {
  config: ChartConfig;
};

// Null outside of a <ChartContainer>; useChart() enforces presence.
const ChartContext = React.createContext<ChartContextProps | null>(null);
/**
 * Returns the chart config provided by the nearest <ChartContainer />.
 * Throws when called outside of one.
 */
function useChart() {
  const ctx = React.useContext(ChartContext);
  if (ctx === null) {
    throw new Error("useChart must be used within a <ChartContainer />");
  }
  return ctx;
}
// Root chart wrapper: provides ChartContext to descendants, injects the
// per-series CSS color variables via <ChartStyle>, and hosts a Recharts
// ResponsiveContainer for the actual chart children.
const ChartContainer = React.forwardRef<
  HTMLDivElement,
  React.ComponentProps<"div"> & {
    config: ChartConfig;
    children: React.ComponentProps<
      typeof RechartsPrimitive.ResponsiveContainer
    >["children"];
  }
>(({ id, className, children, config, ...props }, ref) => {
  const uniqueId = React.useId();
  // Colons emitted by useId are stripped so the id is safe to use inside
  // the [data-chart=…] CSS selector produced by ChartStyle.
  const chartId = `chart-${id || uniqueId.replace(/:/g, "")}`;

  return (
    <ChartContext.Provider value={{ config }}>
      <div
        data-chart={chartId}
        ref={ref}
        className={cn(
          "flex aspect-video justify-center text-xs [&_.recharts-cartesian-axis-tick_text]:fill-muted-foreground [&_.recharts-cartesian-grid_line[stroke='#ccc']]:stroke-border/50 [&_.recharts-curve.recharts-tooltip-cursor]:stroke-border [&_.recharts-dot[stroke='#fff']]:stroke-transparent [&_.recharts-layer]:outline-none [&_.recharts-polar-grid_[stroke='#ccc']]:stroke-border [&_.recharts-radial-bar-background-sector]:fill-muted [&_.recharts-rectangle.recharts-tooltip-cursor]:fill-muted [&_.recharts-reference-line_[stroke='#ccc']]:stroke-border [&_.recharts-sector[stroke='#fff']]:stroke-transparent [&_.recharts-sector]:outline-none [&_.recharts-surface]:outline-none",
          className
        )}
        {...props}
      >
        <ChartStyle id={chartId} config={config} />
        <RechartsPrimitive.ResponsiveContainer>
          {children}
        </RechartsPrimitive.ResponsiveContainer>
      </div>
    </ChartContext.Provider>
  );
});
ChartContainer.displayName = "Chart";
// Emits a <style> tag that defines a `--color-<key>` CSS variable for every
// config entry declaring a color, once per theme selector in THEMES, scoped
// to the container's [data-chart=<id>] attribute.
const ChartStyle = ({ id, config }: { id: string; config: ChartConfig }) => {
  const colorConfig = Object.entries(config).filter(
    ([_, config]) => config.theme || config.color
  );

  // Nothing to emit when no entry declares a color.
  if (!colorConfig.length) {
    return null;
  }

  return (
    <style
      dangerouslySetInnerHTML={{
        __html: Object.entries(THEMES)
          .map(
            ([theme, prefix]) => `
${prefix} [data-chart=${id}] {
${colorConfig
  .map(([key, itemConfig]) => {
    const color =
      itemConfig.theme?.[theme as keyof typeof itemConfig.theme] ||
      itemConfig.color;
    return color ? `  --color-${key}: ${color};` : null;
  })
  .join("\n")}
}
`
          )
          .join("\n"),
      }}
    />
  );
};
const ChartTooltip = RechartsPrimitive.Tooltip;
// Custom tooltip body for Recharts charts. Resolves labels/icons/colors for
// each payload item from the ChartConfig, supports custom `formatter` /
// `labelFormatter` overrides, and three indicator styles (dot/line/dashed).
const ChartTooltipContent = React.forwardRef<
  HTMLDivElement,
  React.ComponentProps<typeof RechartsPrimitive.Tooltip> &
    React.ComponentProps<"div"> & {
      hideLabel?: boolean;
      hideIndicator?: boolean;
      indicator?: "line" | "dot" | "dashed";
      nameKey?: string;
      labelKey?: string;
    }
>(
  (
    {
      active,
      payload,
      className,
      indicator = "dot",
      hideLabel = false,
      hideIndicator = false,
      label,
      labelFormatter,
      labelClassName,
      formatter,
      color,
      nameKey,
      labelKey,
    },
    ref
  ) => {
    const { config } = useChart();

    // Header label for the tooltip: config label (or raw label string) for
    // the first payload item, optionally passed through labelFormatter.
    const tooltipLabel = React.useMemo(() => {
      if (hideLabel || !payload?.length) {
        return null;
      }

      const [item] = payload;
      const key = `${labelKey || item.dataKey || item.name || "value"}`;
      const itemConfig = getPayloadConfigFromPayload(config, item, key);
      const value =
        !labelKey && typeof label === "string"
          ? config[label as keyof typeof config]?.label || label
          : itemConfig?.label;

      if (labelFormatter) {
        return (
          <div className={cn("font-medium", labelClassName)}>
            {labelFormatter(value, payload)}
          </div>
        );
      }

      if (!value) {
        return null;
      }

      return <div className={cn("font-medium", labelClassName)}>{value}</div>;
    }, [
      label,
      labelFormatter,
      payload,
      hideLabel,
      labelClassName,
      config,
      labelKey,
    ]);

    // Render nothing while the tooltip is inactive or has no data.
    if (!active || !payload?.length) {
      return null;
    }

    // With a single non-dot item the label nests next to the value instead
    // of sitting above the list.
    const nestLabel = payload.length === 1 && indicator !== "dot";

    return (
      <div
        ref={ref}
        className={cn(
          "grid min-w-[8rem] items-start gap-1.5 rounded-lg border border-border/50 bg-background px-2.5 py-1.5 text-xs shadow-xl",
          className
        )}
      >
        {!nestLabel ? tooltipLabel : null}
        <div className="grid gap-1.5">
          {payload.map((item, index) => {
            const key = `${nameKey || item.name || item.dataKey || "value"}`;
            const itemConfig = getPayloadConfigFromPayload(config, item, key);
            // Explicit `color` prop wins, then the datum's fill, then the
            // series color supplied by Recharts.
            const indicatorColor = color || item.payload.fill || item.color;

            return (
              <div
                key={item.dataKey}
                className={cn(
                  "flex w-full flex-wrap items-stretch gap-2 [&>svg]:h-2.5 [&>svg]:w-2.5 [&>svg]:text-muted-foreground",
                  indicator === "dot" && "items-center"
                )}
              >
                {formatter && item?.value !== undefined && item.name ? (
                  formatter(item.value, item.name, item, index, item.payload)
                ) : (
                  <>
                    {itemConfig?.icon ? (
                      <itemConfig.icon />
                    ) : (
                      !hideIndicator && (
                        <div
                          className={cn(
                            "shrink-0 rounded-[2px] border-[--color-border] bg-[--color-bg]",
                            {
                              "h-2.5 w-2.5": indicator === "dot",
                              "w-1": indicator === "line",
                              "w-0 border-[1.5px] border-dashed bg-transparent":
                                indicator === "dashed",
                              "my-0.5": nestLabel && indicator === "dashed",
                            }
                          )}
                          style={
                            {
                              "--color-bg": indicatorColor,
                              "--color-border": indicatorColor,
                            } as React.CSSProperties
                          }
                        />
                      )
                    )}
                    <div
                      className={cn(
                        "flex flex-1 justify-between leading-none",
                        nestLabel ? "items-end" : "items-center"
                      )}
                    >
                      <div className="grid gap-1.5">
                        {nestLabel ? tooltipLabel : null}
                        <span className="text-muted-foreground">
                          {itemConfig?.label || item.name}
                        </span>
                      </div>
                      {item.value && (
                        <span className="font-mono font-medium tabular-nums text-foreground">
                          {item.value.toLocaleString()}
                        </span>
                      )}
                    </div>
                  </>
                )}
              </div>
            );
          })}
        </div>
      </div>
    );
  }
);
ChartTooltipContent.displayName = "ChartTooltip";
const ChartLegend = RechartsPrimitive.Legend;
// Custom legend body for Recharts charts: one entry per payload item, using
// the config icon when available (unless hideIcon) or a color swatch, plus
// the config label.
const ChartLegendContent = React.forwardRef<
  HTMLDivElement,
  React.ComponentProps<"div"> &
    Pick<RechartsPrimitive.LegendProps, "payload" | "verticalAlign"> & {
      hideIcon?: boolean;
      nameKey?: string;
    }
>(
  (
    { className, hideIcon = false, payload, verticalAlign = "bottom", nameKey },
    ref
  ) => {
    const { config } = useChart();

    // No legend without data.
    if (!payload?.length) {
      return null;
    }

    return (
      <div
        ref={ref}
        className={cn(
          "flex items-center justify-center gap-4",
          verticalAlign === "top" ? "pb-3" : "pt-3",
          className
        )}
      >
        {payload.map((item) => {
          const key = `${nameKey || item.dataKey || "value"}`;
          const itemConfig = getPayloadConfigFromPayload(config, item, key);

          return (
            <div
              key={item.value}
              className={cn(
                "flex items-center gap-1.5 [&>svg]:h-3 [&>svg]:w-3 [&>svg]:text-muted-foreground"
              )}
            >
              {itemConfig?.icon && !hideIcon ? (
                <itemConfig.icon />
              ) : (
                <div
                  className="h-2 w-2 shrink-0 rounded-[2px]"
                  style={{
                    backgroundColor: item.color,
                  }}
                />
              )}
              {itemConfig?.label}
            </div>
          );
        })}
      </div>
    );
  }
);
ChartLegendContent.displayName = "ChartLegend";
/**
 * Resolve the ChartConfig entry for a tooltip/legend payload item.
 *
 * Looks up `key` first on the item itself, then on its nested `payload`
 * object; whichever holds a string value becomes the config lookup key.
 * Falls back to `config[key]` when neither override exists, and returns
 * undefined for non-object payloads.
 */
function getPayloadConfigFromPayload(
  config: ChartConfig,
  payload: unknown,
  key: string
) {
  if (payload === null || typeof payload !== "object") {
    return undefined;
  }

  // Recharts nests the original datum under `payload.payload`.
  const inner =
    "payload" in payload &&
    typeof payload.payload === "object" &&
    payload.payload !== null
      ? payload.payload
      : undefined;

  // Read obj[k] only when it is a string; otherwise undefined.
  const readString = (obj: object, k: string): string | undefined => {
    const v = (obj as Record<string, unknown>)[k];
    return typeof v === "string" ? v : undefined;
  };

  const configLabelKey =
    readString(payload, key) ?? (inner && readString(inner, key)) ?? key;

  return configLabelKey in config
    ? config[configLabelKey]
    : config[key as keyof typeof config];
}
export {
ChartContainer,
ChartTooltip,
ChartTooltipContent,
ChartLegend,
ChartLegendContent,
ChartStyle,
};

View File

@@ -0,0 +1,25 @@
import * as React from "react";
import { cn } from "@/lib/utils";
// Props for <Input /> — currently just the native <input> attributes; kept
// as a named interface so consumers can extend/augment it later.
export interface InputProps
  extends React.InputHTMLAttributes<HTMLInputElement> {}
const Input = React.forwardRef<HTMLInputElement, InputProps>(
({ className, type, ...props }, ref) => {
return (
<input
type={type}
className={cn(
"flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50",
className
)}
ref={ref}
{...props}
/>
);
}
);
Input.displayName = "Input";
export { Input };

View File

@@ -0,0 +1,117 @@
import * as React from "react";
import { cn } from "@/lib/utils";
const Table = React.forwardRef<
HTMLTableElement,
React.HTMLAttributes<HTMLTableElement>
>(({ className, ...props }, ref) => (
<div className="relative w-full overflow-auto">
<table
ref={ref}
className={cn("w-full caption-bottom text-sm", className)}
{...props}
/>
</div>
));
Table.displayName = "Table";
const TableHeader = React.forwardRef<
HTMLTableSectionElement,
React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
<thead ref={ref} className={cn("[&_tr]:border-b", className)} {...props} />
));
TableHeader.displayName = "TableHeader";
const TableBody = React.forwardRef<
HTMLTableSectionElement,
React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
<tbody
ref={ref}
className={cn("[&_tr:last-child]:border-0", className)}
{...props}
/>
));
TableBody.displayName = "TableBody";
const TableFooter = React.forwardRef<
HTMLTableSectionElement,
React.HTMLAttributes<HTMLTableSectionElement>
>(({ className, ...props }, ref) => (
<tfoot
ref={ref}
className={cn(
"border-t bg-muted/50 font-medium [&>tr]:last:border-b-0",
className
)}
{...props}
/>
));
TableFooter.displayName = "TableFooter";
const TableRow = React.forwardRef<
HTMLTableRowElement,
React.HTMLAttributes<HTMLTableRowElement>
>(({ className, ...props }, ref) => (
<tr
ref={ref}
className={cn(
"border-b transition-colors hover:bg-muted/50 data-[state=selected]:bg-muted",
className
)}
{...props}
/>
));
TableRow.displayName = "TableRow";
const TableHead = React.forwardRef<
HTMLTableCellElement,
React.ThHTMLAttributes<HTMLTableCellElement>
>(({ className, ...props }, ref) => (
<th
ref={ref}
className={cn(
"h-12 px-4 text-left align-middle font-medium text-muted-foreground [&:has([role=checkbox])]:pr-0",
className
)}
{...props}
/>
));
TableHead.displayName = "TableHead";
const TableCell = React.forwardRef<
HTMLTableCellElement,
React.TdHTMLAttributes<HTMLTableCellElement>
>(({ className, ...props }, ref) => (
<td
ref={ref}
className={cn("p-4 align-middle [&:has([role=checkbox])]:pr-0", className)}
{...props}
/>
));
TableCell.displayName = "TableCell";
const TableCaption = React.forwardRef<
HTMLTableCaptionElement,
React.HTMLAttributes<HTMLTableCaptionElement>
>(({ className, ...props }, ref) => (
<caption
ref={ref}
className={cn("mt-4 text-sm text-muted-foreground", className)}
{...props}
/>
));
TableCaption.displayName = "TableCaption";
export {
Table,
TableHeader,
TableBody,
TableFooter,
TableHead,
TableRow,
TableCell,
TableCaption,
};

View File

@@ -30,6 +30,9 @@ export const SIDEBAR_WIDTH = `w-[350px]`;
// True when the NEXT_PUBLIC_DISABLE_LOGOUT env var is the string "true"
// (case-insensitive); any other value, or unset, yields false.
export const LOGOUT_DISABLED =
  process.env.NEXT_PUBLIC_DISABLE_LOGOUT?.toLowerCase() === "true";

// True when NEXT_PUBLIC_DISABLE_CSV_DISPLAY is "true".
// NOTE(review): constant is named DISABLED_… while the env var says
// DISABLE_… — consider aligning the two for consistency.
export const DISABLED_CSV_DISPLAY =
  process.env.NEXT_PUBLIC_DISABLE_CSV_DISPLAY?.toLowerCase() === "true";

// True when NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN is "true".
export const NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN =
  process.env.NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN?.toLowerCase() === "true";

View File

@@ -22,6 +22,8 @@ export interface AnswerPiecePacket {
// Why a response stream stopped. String enum, so serialized values are
// stable over the wire.
// NOTE(review): member semantics below are inferred from the names —
// confirm against the backend that emits them.
export enum StreamStopReason {
  CONTEXT_LENGTH = "CONTEXT_LENGTH", // presumably hit the model context limit
  CANCELLED = "CANCELLED", // presumably cancelled mid-stream
  FINISHED = "FINISHED", // presumably completed normally
  NEW_RESPONSE = "NEW_RESPONSE", // presumably superseded by a newer response
}
export interface StreamStopInfo {

6
web/src/lib/utils.ts Normal file
View File

@@ -0,0 +1,6 @@
import { type ClassValue, clsx } from "clsx";
import { twMerge } from "tailwind-merge";
/**
 * Build a className string: clsx flattens/conditionalizes the inputs, then
 * tailwind-merge resolves conflicting Tailwind utilities (last one wins).
 */
export function cn(...inputs: ClassValue[]) {
  const flattened = clsx(inputs);
  return twMerge(flattened);
}

80
web/tailwind.config.ts Normal file
View File

@@ -0,0 +1,80 @@
import type { Config } from "tailwindcss";
// shadcn/ui-style Tailwind config: class-based dark mode, a centered
// container, and semantic theme tokens resolved from CSS custom properties.
const config = {
  darkMode: ["class"],
  content: [
    "./pages/**/*.{ts,tsx}",
    "./components/**/*.{ts,tsx}",
    "./app/**/*.{ts,tsx}",
    "./src/**/*.{ts,tsx}",
  ],
  prefix: "",
  theme: {
    container: {
      center: true,
      padding: "2rem",
      screens: {
        "2xl": "1400px",
      },
    },
    extend: {
      // Semantic color tokens; the --… variables are defined in the app's
      // global stylesheet (not visible here — confirm their location).
      colors: {
        border: "hsl(var(--border))",
        input: "hsl(var(--input))",
        ring: "hsl(var(--ring))",
        background: "hsl(var(--background))",
        foreground: "hsl(var(--foreground))",
        primary: {
          DEFAULT: "hsl(var(--primary))",
          foreground: "hsl(var(--primary-foreground))",
        },
        secondary: {
          DEFAULT: "hsl(var(--secondary))",
          foreground: "hsl(var(--secondary-foreground))",
        },
        destructive: {
          DEFAULT: "hsl(var(--destructive))",
          foreground: "hsl(var(--destructive-foreground))",
        },
        muted: {
          DEFAULT: "hsl(var(--muted))",
          foreground: "hsl(var(--muted-foreground))",
        },
        accent: {
          DEFAULT: "hsl(var(--accent))",
          foreground: "hsl(var(--accent-foreground))",
        },
        popover: {
          DEFAULT: "hsl(var(--popover))",
          foreground: "hsl(var(--popover-foreground))",
        },
        card: {
          DEFAULT: "hsl(var(--card))",
          foreground: "hsl(var(--card-foreground))",
        },
      },
      borderRadius: {
        lg: "var(--radius)",
        md: "calc(var(--radius) - 2px)",
        sm: "calc(var(--radius) - 4px)",
      },
      // Radix accordion open/close height animations.
      keyframes: {
        "accordion-down": {
          from: { height: "0" },
          to: { height: "var(--radix-accordion-content-height)" },
        },
        "accordion-up": {
          from: { height: "var(--radix-accordion-content-height)" },
          to: { height: "0" },
        },
      },
      animation: {
        "accordion-down": "accordion-down 0.2s ease-out",
        "accordion-up": "accordion-up 0.2s ease-out",
      },
    },
  },
  plugins: [require("tailwindcss-animate")],
} satisfies Config;

export default config;