mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-01 13:45:44 +00:00
Compare commits
36 Commits
v1.0.0-clo
...
final_grap
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0d3eb28e8 | ||
|
|
d789c9ac52 | ||
|
|
d989ce13e7 | ||
|
|
3170430673 | ||
|
|
2beffdaa6e | ||
|
|
77ee061e67 | ||
|
|
532bc53a9a | ||
|
|
7b7b95703d | ||
|
|
fcc5efdaf8 | ||
|
|
1ea4a53af1 | ||
|
|
47479c8799 | ||
|
|
fbc5008259 | ||
|
|
d684fb116d | ||
|
|
2e61b374f4 | ||
|
|
15d324834f | ||
|
|
de9a9b7b6e | ||
|
|
47eb8c521d | ||
|
|
875fb05dca | ||
|
|
1285b2f4d4 | ||
|
|
842628771b | ||
|
|
7a9d5bd92e | ||
|
|
4f3b513ccb | ||
|
|
cd454dd780 | ||
|
|
9140ee99cb | ||
|
|
a64f27c895 | ||
|
|
fdf5611a35 | ||
|
|
c4f483d100 | ||
|
|
fc28c6b9e1 | ||
|
|
33e25dbd8b | ||
|
|
659e8cb69e | ||
|
|
681175e9c3 | ||
|
|
de18ec7ea4 | ||
|
|
9edbb0806d | ||
|
|
63d10e7482 | ||
|
|
ff6a15b5af | ||
|
|
49397e8a86 |
BIN
backend/aaa garp.png
Normal file
BIN
backend/aaa garp.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 27 KiB |
@@ -0,0 +1,65 @@
|
||||
"""single tool call per message
|
||||
|
||||
|
||||
Revision ID: 4e8e7ae58189
|
||||
Revises: 5c7fdadae813
|
||||
Create Date: 2024-09-09 10:07:58.008838
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "4e8e7ae58189"
|
||||
down_revision = "5c7fdadae813"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create the new column
|
||||
op.add_column(
|
||||
"chat_message", sa.Column("tool_call_id", sa.Integer(), nullable=True)
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"fk_chat_message_tool_call",
|
||||
"chat_message",
|
||||
"tool_call",
|
||||
["tool_call_id"],
|
||||
["id"],
|
||||
)
|
||||
|
||||
# Migrate existing data
|
||||
op.execute(
|
||||
"UPDATE chat_message SET tool_call_id = (SELECT id FROM tool_call WHERE tool_call.message_id = chat_message.id LIMIT 1)"
|
||||
)
|
||||
|
||||
# Drop the old relationship
|
||||
op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
|
||||
op.drop_column("tool_call", "message_id")
|
||||
|
||||
# Add a unique constraint to ensure one-to-one relationship
|
||||
op.create_unique_constraint(
|
||||
"uq_chat_message_tool_call_id", "chat_message", ["tool_call_id"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back the old column
|
||||
op.add_column(
|
||||
"tool_call",
|
||||
sa.Column("message_id", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
)
|
||||
op.create_foreign_key(
|
||||
"tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"]
|
||||
)
|
||||
|
||||
# Migrate data back
|
||||
op.execute(
|
||||
"UPDATE tool_call SET message_id = (SELECT id FROM chat_message WHERE chat_message.tool_call_id = tool_call.id)"
|
||||
)
|
||||
|
||||
# Drop the new column
|
||||
op.drop_constraint("fk_chat_message_tool_call", "chat_message", type_="foreignkey")
|
||||
op.drop_column("chat_message", "tool_call_id")
|
||||
@@ -11,6 +11,7 @@ from danswer.search.enums import SearchType
|
||||
from danswer.search.models import RetrievalDocs
|
||||
from danswer.search.models import SearchResponse
|
||||
from danswer.tools.custom.base_tool_types import ToolResultType
|
||||
from danswer.tools.graphing.models import GraphGenerationDisplay
|
||||
|
||||
|
||||
class LlmDoc(BaseModel):
|
||||
@@ -48,6 +49,8 @@ class QADocsResponse(RetrievalDocs):
|
||||
class StreamStopReason(Enum):
|
||||
CONTEXT_LENGTH = "context_length"
|
||||
CANCELLED = "cancelled"
|
||||
FINISHED = "finished"
|
||||
NEW_RESPONSE = "new_response"
|
||||
|
||||
|
||||
class StreamStopInfo(BaseModel):
|
||||
@@ -173,6 +176,7 @@ AnswerQuestionPossibleReturn = (
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| StreamingError
|
||||
| GraphGenerationDisplay
|
||||
| StreamStopInfo
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,8 @@ from danswer.chat.models import MessageResponseIDInfo
|
||||
from danswer.chat.models import MessageSpecificCitations
|
||||
from danswer.chat.models import QADocsResponse
|
||||
from danswer.chat.models import StreamingError
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import BING_API_KEY
|
||||
from danswer.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
@@ -72,11 +74,16 @@ from danswer.search.utils import relevant_sections_to_indices
|
||||
from danswer.server.query_and_chat.models import ChatMessageDetail
|
||||
from danswer.server.query_and_chat.models import CreateChatMessageRequest
|
||||
from danswer.server.utils import get_json_line
|
||||
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
|
||||
from danswer.tools.built_in_tools import get_built_in_tool_by_id
|
||||
from danswer.tools.custom.custom_tool import build_custom_tools_from_openapi_schema
|
||||
from danswer.tools.custom.custom_tool import CUSTOM_TOOL_RESPONSE_ID
|
||||
from danswer.tools.custom.custom_tool import CustomToolCallSummary
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.graphing.graphing_tool import GraphingResponse
|
||||
from danswer.tools.graphing.graphing_tool import GraphingTool
|
||||
from danswer.tools.graphing.models import GraphGenerationDisplay
|
||||
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
@@ -243,6 +250,7 @@ def _get_force_search_settings(
|
||||
ChatPacket = (
|
||||
StreamingError
|
||||
| QADocsResponse
|
||||
| GraphingResponse
|
||||
| LLMRelevanceFilterResponse
|
||||
| FinalUsedContextDocsResponse
|
||||
| ChatMessageDetail
|
||||
@@ -251,6 +259,7 @@ ChatPacket = (
|
||||
| CitationInfo
|
||||
| ImageGenerationDisplay
|
||||
| CustomToolResponse
|
||||
| GraphGenerationDisplay
|
||||
| MessageSpecificCitations
|
||||
| MessageResponseIDInfo
|
||||
)
|
||||
@@ -528,8 +537,21 @@ def stream_chat_message_objects(
|
||||
tool_dict: dict[int, list[Tool]] = {} # tool_id to tool
|
||||
for db_tool_model in persona.tools:
|
||||
# handle in-code tools specially
|
||||
|
||||
if db_tool_model.in_code_tool_id:
|
||||
tool_cls = get_built_in_tool_by_id(db_tool_model.id, db_session)
|
||||
if (
|
||||
tool_cls.__name__ == CSVAnalysisTool.__name__
|
||||
and not latest_query_files
|
||||
):
|
||||
tool_dict[db_tool_model.id] = [CSVAnalysisTool()]
|
||||
|
||||
if (
|
||||
tool_cls.__name__ == GraphingTool.__name__
|
||||
and not latest_query_files
|
||||
):
|
||||
tool_dict[db_tool_model.id] = [GraphingTool(output_dir="output")]
|
||||
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not latest_query_files:
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
@@ -600,7 +622,6 @@ def stream_chat_message_objects(
|
||||
]
|
||||
|
||||
continue
|
||||
|
||||
# handle all custom tools
|
||||
if db_tool_model.openapi_schema:
|
||||
tool_dict[db_tool_model.id] = cast(
|
||||
@@ -617,6 +638,11 @@ def stream_chat_message_objects(
|
||||
tools: list[Tool] = []
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
document_pruning_config.tool_num_tokens = compute_all_tool_tokens(
|
||||
@@ -662,86 +688,180 @@ def stream_chat_message_objects(
|
||||
ai_message_files = None # any files to associate with the AI message e.g. dall-e generated images
|
||||
dropped_indices = None
|
||||
tool_result = None
|
||||
yielded_message_id_info = True
|
||||
|
||||
for packet in answer.processed_streamed_output:
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
if isinstance(packet, StreamStopInfo):
|
||||
if packet.stop_reason is not StreamStopReason.NEW_RESPONSE:
|
||||
break
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
db_citations = None
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
if reference_db_search_docs:
|
||||
db_citations = _translate_citations(
|
||||
citations_list=answer.citations,
|
||||
db_docs=reference_db_search_docs,
|
||||
)
|
||||
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(CustomToolCallSummary, packet.response)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
# Saving Gen AI answer and responding with message info
|
||||
if tool_result is None:
|
||||
tool_call = None
|
||||
else:
|
||||
tool_call = ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments={
|
||||
k: v if not isinstance(v, bytes) else v.decode("utf-8")
|
||||
for k, v in tool_result.tool_args.items()
|
||||
},
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
message=answer.llm_answer,
|
||||
rephrased_query=cast(
|
||||
QADocsResponse, qa_docs_response
|
||||
).rephrased_query
|
||||
if qa_docs_response is not None
|
||||
else None,
|
||||
reference_docs=reference_db_search_docs,
|
||||
files=ai_message_files,
|
||||
token_count=len(llm_tokenizer_encode_func(answer.llm_answer)),
|
||||
citations=cast(MessageSpecificCitations, db_citations).citation_map
|
||||
if db_citations is not None
|
||||
else None,
|
||||
error=None,
|
||||
tool_call=tool_call,
|
||||
)
|
||||
|
||||
db_session.commit() # actually save user / assistant message
|
||||
|
||||
msg_detail_response = translate_db_message_to_chat_message_detail(
|
||||
gen_ai_response_message
|
||||
)
|
||||
|
||||
yield msg_detail_response
|
||||
reserved_message_id = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message.id
|
||||
if user_message is not None
|
||||
else gen_ai_response_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
yielded_message_id_info = False
|
||||
|
||||
partial_response = partial(
|
||||
create_new_chat_message,
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message=gen_ai_response_message,
|
||||
prompt_id=prompt_id,
|
||||
overridden_model=overridden_model,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
alternate_assistant_id=new_msg_req.alternate_assistant_id,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
reference_db_search_docs = None
|
||||
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
if not yielded_message_id_info:
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=gen_ai_response_message.id,
|
||||
reserved_assistant_message_id=reserved_message_id,
|
||||
)
|
||||
yielded_message_id_info = True
|
||||
|
||||
if isinstance(packet, ToolResponse):
|
||||
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
dropped_indices,
|
||||
) = _handle_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
selected_search_docs=selected_db_search_docs,
|
||||
# Deduping happens at the last step to avoid harming quality by dropping content early on
|
||||
dedupe_docs=retrieval_options.dedupe_docs
|
||||
if retrieval_options
|
||||
else False,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == SECTION_RELEVANCE_LIST_ID:
|
||||
relevance_sections = packet.response
|
||||
|
||||
if reference_db_search_docs is not None:
|
||||
llm_indices = relevant_sections_to_indices(
|
||||
relevance_sections=relevance_sections,
|
||||
items=[
|
||||
translate_db_search_doc_to_server_search_doc(doc)
|
||||
for doc in reference_db_search_docs
|
||||
],
|
||||
)
|
||||
|
||||
if dropped_indices:
|
||||
llm_indices = drop_llm_indices(
|
||||
llm_indices=llm_indices,
|
||||
search_docs=reference_db_search_docs,
|
||||
dropped_indices=dropped_indices,
|
||||
)
|
||||
|
||||
yield LLMRelevanceFilterResponse(
|
||||
llm_selected_doc_indices=llm_indices
|
||||
)
|
||||
|
||||
elif packet.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
yield FinalUsedContextDocsResponse(
|
||||
final_context_docs=packet.response
|
||||
)
|
||||
elif packet.id == GRAPHING_RESPONSE_ID:
|
||||
graph_generation = cast(GraphingResponse, packet.response)
|
||||
yield graph_generation
|
||||
|
||||
# yield GraphGenerationDisplay(
|
||||
# file_id=graph_generation.extra_graph_display.file_id,
|
||||
# line_graph=graph_generation.extra_graph_display.line_graph,
|
||||
# )
|
||||
|
||||
elif packet.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], packet.response
|
||||
)
|
||||
|
||||
file_ids = save_files_from_urls(
|
||||
[img.url for img in img_generation_response]
|
||||
)
|
||||
ai_message_files = [
|
||||
FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
|
||||
for file_id in file_ids
|
||||
]
|
||||
yield ImageGenerationDisplay(
|
||||
file_ids=[str(file_id) for file_id in file_ids]
|
||||
)
|
||||
elif packet.id == INTERNET_SEARCH_RESPONSE_ID:
|
||||
(
|
||||
qa_docs_response,
|
||||
reference_db_search_docs,
|
||||
) = _handle_internet_search_tool_response_summary(
|
||||
packet=packet,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield qa_docs_response
|
||||
elif packet.id == CUSTOM_TOOL_RESPONSE_ID:
|
||||
custom_tool_response = cast(
|
||||
CustomToolCallSummary, packet.response
|
||||
)
|
||||
yield CustomToolResponse(
|
||||
response=custom_tool_response.tool_result,
|
||||
tool_name=custom_tool_response.tool_name,
|
||||
)
|
||||
else:
|
||||
if isinstance(packet, ToolCallFinalResult):
|
||||
tool_result = packet
|
||||
yield cast(ChatPacket, packet)
|
||||
|
||||
logger.debug("Reached end of stream")
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
@@ -767,11 +887,8 @@ def stream_chat_message_objects(
|
||||
)
|
||||
yield AllCitations(citations=answer.citations)
|
||||
|
||||
# Saving Gen AI answer and responding with message info
|
||||
tool_name_to_tool_id: dict[str, int] = {}
|
||||
for tool_id, tool_list in tool_dict.items():
|
||||
for tool in tool_list:
|
||||
tool_name_to_tool_id[tool.name] = tool_id
|
||||
if answer.llm_answer == "":
|
||||
return
|
||||
|
||||
gen_ai_response_message = partial_response(
|
||||
reserved_message_id=reserved_message_id,
|
||||
@@ -786,16 +903,14 @@ def stream_chat_message_objects(
|
||||
if message_specific_citations
|
||||
else None,
|
||||
error=None,
|
||||
tool_calls=[
|
||||
ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
]
|
||||
tool_call=ToolCall(
|
||||
tool_id=tool_name_to_tool_id[tool_result.tool_name],
|
||||
tool_name=tool_result.tool_name,
|
||||
tool_arguments=tool_result.tool_args,
|
||||
tool_result=tool_result.tool_result,
|
||||
)
|
||||
if tool_result
|
||||
else [],
|
||||
else None,
|
||||
)
|
||||
|
||||
logger.debug("Committing messages")
|
||||
|
||||
@@ -135,7 +135,7 @@ POSTGRES_PASSWORD = urllib.parse.quote_plus(
|
||||
os.environ.get("POSTGRES_PASSWORD") or "password"
|
||||
)
|
||||
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "localhost"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
|
||||
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5433"
|
||||
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
|
||||
|
||||
# defaults to False
|
||||
|
||||
@@ -165,6 +165,7 @@ class FileOrigin(str, Enum):
|
||||
CONNECTOR = "connector"
|
||||
GENERATED_REPORT = "generated_report"
|
||||
OTHER = "other"
|
||||
GRAPH_GEN = "graph_gen"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
|
||||
@@ -178,8 +178,14 @@ def delete_search_doc_message_relationship(
|
||||
|
||||
|
||||
def delete_tool_call_for_message_id(message_id: int, db_session: Session) -> None:
|
||||
stmt = delete(ToolCall).where(ToolCall.message_id == message_id)
|
||||
db_session.execute(stmt)
|
||||
chat_message = (
|
||||
db_session.query(ChatMessage).filter(ChatMessage.id == message_id).first()
|
||||
)
|
||||
if chat_message and chat_message.tool_call_id:
|
||||
stmt = delete(ToolCall).where(ToolCall.id == chat_message.tool_call_id)
|
||||
db_session.execute(stmt)
|
||||
chat_message.tool_call_id = None
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -388,7 +394,7 @@ def get_chat_messages_by_session(
|
||||
)
|
||||
|
||||
if prefetch_tool_calls:
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_calls))
|
||||
stmt = stmt.options(joinedload(ChatMessage.tool_call))
|
||||
result = db_session.scalars(stmt).unique().all()
|
||||
else:
|
||||
result = db_session.scalars(stmt).all()
|
||||
@@ -474,7 +480,7 @@ def create_new_chat_message(
|
||||
alternate_assistant_id: int | None = None,
|
||||
# Maps the citation number [n] to the DB SearchDoc
|
||||
citations: dict[int, int] | None = None,
|
||||
tool_calls: list[ToolCall] | None = None,
|
||||
tool_call: ToolCall | None = None,
|
||||
commit: bool = True,
|
||||
reserved_message_id: int | None = None,
|
||||
overridden_model: str | None = None,
|
||||
@@ -494,7 +500,7 @@ def create_new_chat_message(
|
||||
existing_message.message_type = message_type
|
||||
existing_message.citations = citations
|
||||
existing_message.files = files
|
||||
existing_message.tool_calls = tool_calls if tool_calls else []
|
||||
existing_message.tool_call = tool_call if tool_call else None
|
||||
existing_message.error = error
|
||||
existing_message.alternate_assistant_id = alternate_assistant_id
|
||||
existing_message.overridden_model = overridden_model
|
||||
@@ -513,7 +519,7 @@ def create_new_chat_message(
|
||||
message_type=message_type,
|
||||
citations=citations,
|
||||
files=files,
|
||||
tool_calls=tool_calls if tool_calls else [],
|
||||
tool_call=tool_call if tool_call else None,
|
||||
error=error,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
overridden_model=overridden_model,
|
||||
@@ -747,14 +753,13 @@ def translate_db_message_to_chat_message_detail(
|
||||
time_sent=chat_message.time_sent,
|
||||
citations=chat_message.citations,
|
||||
files=chat_message.files or [],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
alternate_assistant_id=chat_message.alternate_assistant_id,
|
||||
overridden_model=chat_message.overridden_model,
|
||||
)
|
||||
|
||||
@@ -854,10 +854,8 @@ class ToolCall(Base):
|
||||
tool_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
|
||||
tool_result: Mapped[JSON_ro] = mapped_column(postgresql.JSONB())
|
||||
|
||||
message_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
|
||||
message: Mapped["ChatMessage"] = relationship(
|
||||
"ChatMessage", back_populates="tool_calls"
|
||||
"ChatMessage", back_populates="tool_call"
|
||||
)
|
||||
|
||||
|
||||
@@ -984,9 +982,14 @@ class ChatMessage(Base):
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_calls: Mapped[list["ToolCall"]] = relationship(
|
||||
"ToolCall",
|
||||
back_populates="message",
|
||||
|
||||
tool_call_id: Mapped[int | None] = mapped_column(
|
||||
ForeignKey("tool_call.id"), nullable=True
|
||||
)
|
||||
# NOTE: Should always be attached to the `assistant` message.
|
||||
# represents the tool calls used to generate this message
|
||||
tool_call: Mapped["ToolCall"] = relationship(
|
||||
"ToolCall", back_populates="message", foreign_keys=[tool_call_id]
|
||||
)
|
||||
standard_answers: Mapped[list["StandardAnswer"]] = relationship(
|
||||
"StandardAnswer",
|
||||
|
||||
@@ -13,6 +13,8 @@ class ChatFileType(str, Enum):
|
||||
DOC = "document"
|
||||
# Plain text only contain the text
|
||||
PLAIN_TEXT = "plain_text"
|
||||
# csv types contain the binary data
|
||||
CSV = "csv"
|
||||
|
||||
|
||||
class FileDescriptor(TypedDict):
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
@@ -16,6 +17,27 @@ from danswer.file_store.models import InMemoryChatFile
|
||||
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
|
||||
def save_base64_image(base64_image: str) -> str:
|
||||
with get_session_context_manager() as db_session:
|
||||
if base64_image.startswith("data:image"):
|
||||
base64_image = base64_image.split(",", 1)[1]
|
||||
|
||||
image_data = base64.b64decode(base64_image)
|
||||
|
||||
unique_id = str(uuid4())
|
||||
|
||||
file_io = BytesIO(image_data)
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_store.save_file(
|
||||
file_name=unique_id,
|
||||
content=file_io,
|
||||
display_name="GeneratedImage",
|
||||
file_origin=FileOrigin.CHAT_IMAGE_GEN,
|
||||
file_type="image/png",
|
||||
)
|
||||
return unique_id
|
||||
|
||||
|
||||
def load_chat_file(
|
||||
file_descriptor: FileDescriptor, db_session: Session
|
||||
) -> InMemoryChatFile:
|
||||
|
||||
@@ -16,6 +16,7 @@ from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import StreamStopInfo
|
||||
from danswer.chat.models import StreamStopReason
|
||||
from danswer.configs.chat_configs import QA_PROMPT_OVERRIDE
|
||||
from danswer.configs.constants import MessageType
|
||||
from danswer.file_store.utils import InMemoryChatFile
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
@@ -40,11 +41,13 @@ from danswer.llm.answering.stream_processing.utils import map_document_id_order
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import ToolChoiceOptions
|
||||
from danswer.natural_language_processing.utils import get_tokenizer
|
||||
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
|
||||
from danswer.tools.custom.custom_tool_prompt_builder import (
|
||||
build_user_message_for_custom_tool_for_non_tool_calling_llm,
|
||||
)
|
||||
from danswer.tools.force import filter_tools_for_force_tool_use
|
||||
from danswer.tools.force import ForceUseTool
|
||||
from danswer.tools.graphing.graphing_tool import GraphingTool
|
||||
from danswer.tools.images.image_generation_tool import IMAGE_GENERATION_RESPONSE_ID
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationResponse
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
@@ -68,6 +71,7 @@ from danswer.tools.tool_runner import ToolRunner
|
||||
from danswer.tools.tool_selection import select_single_tool_for_non_tool_calling_llm
|
||||
from danswer.tools.utils import explicit_tool_calling_supported
|
||||
from danswer.utils.logger import setup_logger
|
||||
from shared_configs.configs import MAX_TOOL_CALLS
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -161,6 +165,10 @@ class Answer:
|
||||
self.skip_gen_ai_answer_generation = skip_gen_ai_answer_generation
|
||||
self._is_cancelled = False
|
||||
|
||||
self.final_context_docs: list = []
|
||||
self.current_streamed_output: list = []
|
||||
self.processing_stream: list = []
|
||||
|
||||
def _update_prompt_builder_for_search_tool(
|
||||
self, prompt_builder: AnswerPromptBuilder, final_context_documents: list[LlmDoc]
|
||||
) -> None:
|
||||
@@ -196,41 +204,50 @@ class Answer:
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
tool_calls = 0
|
||||
initiated = False
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
# if we are forcing a tool WITH args specified, we don't need to check which tools to run
|
||||
# / need to generate the args
|
||||
tool_call_chunk = AIMessageChunk(
|
||||
content="",
|
||||
)
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
else:
|
||||
# if tool calling is supported, first try the raw message
|
||||
# to see if we don't need to use any tools
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
]
|
||||
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
|
||||
if initiated:
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
initiated = True
|
||||
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
|
||||
tool_call_chunk: AIMessageChunk | None = None
|
||||
|
||||
if self.force_use_tool.force_use and self.force_use_tool.args is not None:
|
||||
tool_call_chunk = AIMessageChunk(content="")
|
||||
tool_call_chunk.tool_calls = [
|
||||
{
|
||||
"name": self.force_use_tool.tool_name,
|
||||
"args": self.force_use_tool.args,
|
||||
"id": str(uuid4()),
|
||||
}
|
||||
]
|
||||
|
||||
else:
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question,
|
||||
self.prompt_config,
|
||||
self.latest_query_files,
|
||||
tool_calls,
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
|
||||
final_tool_definitions = [
|
||||
tool.tool_definition()
|
||||
for tool in filter_tools_for_force_tool_use(
|
||||
self.tools, self.force_use_tool
|
||||
)
|
||||
]
|
||||
|
||||
print(final_tool_definitions)
|
||||
for message in self.llm.stream(
|
||||
prompt=prompt,
|
||||
tools=final_tool_definitions if final_tool_definitions else None,
|
||||
@@ -257,67 +274,129 @@ class Answer:
|
||||
)
|
||||
|
||||
if not tool_call_chunk:
|
||||
return # no tool call needed
|
||||
logger.info("Skipped tool call but generated message")
|
||||
return
|
||||
|
||||
# if we have a tool call, we need to call the tool
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
for tool_call_request in tool_call_requests:
|
||||
known_tools_by_name = [
|
||||
tool for tool in self.tools if tool.name == tool_call_request["name"]
|
||||
]
|
||||
tool_call_requests = tool_call_chunk.tool_calls
|
||||
print(tool_call_requests)
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
yield from tool_runner.tool_responses()
|
||||
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
for tool_call_request in tool_call_requests:
|
||||
tool_calls += 1
|
||||
known_tools_by_name = [
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == tool_call_request["name"]
|
||||
]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
|
||||
if not known_tools_by_name:
|
||||
logger.error(
|
||||
"Tool call requested with unknown name field. \n"
|
||||
f"self.tools: {self.tools}"
|
||||
f"tool_call_request: {tool_call_request}"
|
||||
)
|
||||
if self.tools:
|
||||
tool = self.tools[0]
|
||||
else:
|
||||
continue
|
||||
else:
|
||||
tool = known_tools_by_name[0]
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.tool_name == tool.name
|
||||
and self.force_use_tool.args
|
||||
else tool_call_request["args"]
|
||||
)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
yield tool_kickoff
|
||||
|
||||
yield from tool_runner.tool_responses()
|
||||
|
||||
tool_responses = []
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
yield response
|
||||
|
||||
tool_call_summary = ToolCallSummary(
|
||||
tool_call_request=tool_call_chunk,
|
||||
tool_call_result=build_tool_message(
|
||||
tool_call_request, tool_runner.tool_message_content()
|
||||
),
|
||||
)
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
self._update_prompt_builder_for_search_tool(prompt_builder, [])
|
||||
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = [
|
||||
img_generation_result["url"]
|
||||
for img_generation_result in tool_runner.tool_final_result().tool_result
|
||||
]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question, img_urls=img_urls
|
||||
)
|
||||
)
|
||||
|
||||
yield tool_runner.tool_final_result()
|
||||
|
||||
# Update message history with tool call and response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=self.question,
|
||||
message_type=MessageType.USER,
|
||||
token_count=len(self.llm_tokenizer.encode(self.question)),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
yield tool_runner.tool_final_result()
|
||||
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=str(tool_call_request),
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(
|
||||
self.llm_tokenizer.encode(str(tool_call_request))
|
||||
),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message="\n".join(str(response) for response in tool_responses),
|
||||
message_type=MessageType.SYSTEM,
|
||||
token_count=len(
|
||||
self.llm_tokenizer.encode(
|
||||
"\n".join(str(response) for response in tool_responses)
|
||||
)
|
||||
),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=[tool.tool_definition() for tool in self.tools],
|
||||
)
|
||||
# Generate response based on updated message history
|
||||
prompt = prompt_builder.build(tool_call_summary=tool_call_summary)
|
||||
|
||||
return
|
||||
response_content = ""
|
||||
for content in self._process_llm_stream(prompt=prompt, tools=None):
|
||||
if isinstance(content, str):
|
||||
response_content += content
|
||||
yield content
|
||||
|
||||
# Update message history with LLM response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=response_content,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(self.llm_tokenizer.encode(response_content)),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
|
||||
# This method processes the LLM stream and yields the content or stop information
|
||||
def _process_llm_stream(
|
||||
@@ -346,139 +425,234 @@ class Answer:
|
||||
) -> Iterator[
|
||||
str | StreamStopInfo | ToolCallKickoff | ToolResponse | ToolCallFinalResult
|
||||
]:
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
tool_calls = 0
|
||||
initiated = False
|
||||
while tool_calls < (1 if self.force_use_tool.force_use else MAX_TOOL_CALLS):
|
||||
if initiated:
|
||||
yield StreamStopInfo(stop_reason=StreamStopReason.NEW_RESPONSE)
|
||||
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
iter(
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(f"Tool '{self.force_use_tool.tool_name}' not found")
|
||||
initiated = True
|
||||
prompt_builder = AnswerPromptBuilder(self.message_history, self.llm.config)
|
||||
chosen_tool_and_args: tuple[Tool, dict] | None = None
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
if self.force_use_tool.force_use:
|
||||
# if we are forcing a tool, we don't need to check which tools to run
|
||||
tool = next(
|
||||
iter(
|
||||
[
|
||||
tool
|
||||
for tool in self.tools
|
||||
if tool.name == self.force_use_tool.tool_name
|
||||
]
|
||||
),
|
||||
None,
|
||||
)
|
||||
if not tool:
|
||||
raise RuntimeError(
|
||||
f"Tool '{self.force_use_tool.tool_name}' not found"
|
||||
)
|
||||
|
||||
tool_args = (
|
||||
self.force_use_tool.args
|
||||
if self.force_use_tool.args is not None
|
||||
else tool.get_args_for_non_tool_calling_llm(
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
force_run=True,
|
||||
)
|
||||
)
|
||||
|
||||
if tool_args is None:
|
||||
raise RuntimeError(f"Tool '{tool.name}' did not return args")
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
chosen_tool_and_args = (tool, tool_args)
|
||||
else:
|
||||
tool_options = check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
tools=self.tools,
|
||||
query=self.question,
|
||||
history=self.message_history,
|
||||
llm=self.llm,
|
||||
)
|
||||
|
||||
available_tools_and_args = [
|
||||
(self.tools[ind], args)
|
||||
for ind, args in enumerate(tool_options)
|
||||
if args is not None
|
||||
]
|
||||
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
logger.info(
|
||||
f"Selecting single tool from tools: {[(tool.name, args) for tool, args in available_tools_and_args]}"
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
chosen_tool_and_args = (
|
||||
select_single_tool_for_non_tool_calling_llm(
|
||||
tools_and_args=available_tools_and_args,
|
||||
history=self.message_history,
|
||||
query=self.question,
|
||||
llm=self.llm,
|
||||
)
|
||||
if available_tools_and_args
|
||||
else None
|
||||
)
|
||||
|
||||
logger.notice(f"Chosen tool: {chosen_tool_and_args}")
|
||||
|
||||
if not chosen_tool_and_args:
|
||||
prompt_builder.update_system_prompt(
|
||||
default_build_system_message(self.prompt_config)
|
||||
)
|
||||
prompt_builder.update_user_prompt(
|
||||
default_build_user_message(
|
||||
self.question, self.prompt_config, self.latest_query_files
|
||||
)
|
||||
)
|
||||
prompt = prompt_builder.build()
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=None,
|
||||
)
|
||||
return
|
||||
tool_calls += 1
|
||||
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
print("tool args")
|
||||
print(tool_args)
|
||||
|
||||
tool_runner = ToolRunner(tool, tool_args, self.llm)
|
||||
yield tool_runner.kickoff()
|
||||
tool_responses = []
|
||||
file_name = tool_runner.args["filename"]
|
||||
|
||||
print(f"file ame is {file_name}")
|
||||
|
||||
csv_file = None
|
||||
for message in self.message_history:
|
||||
if message.files:
|
||||
csv_file = next(
|
||||
(file for file in message.files if file.filename == file_name),
|
||||
None,
|
||||
)
|
||||
if csv_file:
|
||||
break
|
||||
print(self.latest_query_files)
|
||||
if csv_file is None:
|
||||
raise ValueError(
|
||||
f"CSV file with name '{file_name}' not found in latest query files."
|
||||
)
|
||||
print("csv file found")
|
||||
|
||||
tool_runner.args["filename"] = csv_file.content
|
||||
|
||||
for response in tool_runner.tool_responses():
|
||||
tool_responses.append(response)
|
||||
yield response
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_documents = cast(list[LlmDoc], response.response)
|
||||
yield response
|
||||
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
)
|
||||
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], response.response
|
||||
)
|
||||
# img_urls = [img.url for img in img_generation_response]
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
img_urls=[img.url for img in img_generation_response],
|
||||
)
|
||||
)
|
||||
|
||||
yield response
|
||||
elif tool.name == CSVAnalysisTool._NAME:
|
||||
for response in tool_runner.tool_responses():
|
||||
yield response
|
||||
|
||||
elif tool.name == GraphingTool._NAME:
|
||||
for response in tool_runner.tool_responses():
|
||||
print("RESOS")
|
||||
print(response)
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
# img_urls=img_urls,
|
||||
)
|
||||
)
|
||||
else:
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
final_result = tool_runner.tool_final_result()
|
||||
yield final_result
|
||||
|
||||
# Update message history
|
||||
self.message_history.extend(
|
||||
[
|
||||
PreviousMessage(
|
||||
message=str(self.question),
|
||||
message_type=MessageType.USER,
|
||||
token_count=len(self.llm_tokenizer.encode(str(self.question))),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
),
|
||||
PreviousMessage(
|
||||
message=f"Tool used: {tool.name}",
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(
|
||||
self.llm_tokenizer.encode(f"Tool used: {tool.name}")
|
||||
),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
),
|
||||
PreviousMessage(
|
||||
message=str(final_result),
|
||||
message_type=MessageType.SYSTEM,
|
||||
token_count=len(self.llm_tokenizer.encode(str(final_result))),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
# Generate response based on updated message history
|
||||
prompt = prompt_builder.build()
|
||||
yield from self._process_llm_stream(
|
||||
prompt=prompt,
|
||||
tools=None,
|
||||
)
|
||||
return
|
||||
|
||||
tool, tool_args = chosen_tool_and_args
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
yield tool_runner.kickoff()
|
||||
response_content = ""
|
||||
for content in self._process_llm_stream(prompt=prompt, tools=None):
|
||||
if isinstance(content, str):
|
||||
response_content += content
|
||||
yield content
|
||||
|
||||
if tool.name in {SearchTool._NAME, InternetSearchTool._NAME}:
|
||||
final_context_documents = None
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_context_documents = cast(list[LlmDoc], response.response)
|
||||
yield response
|
||||
|
||||
if final_context_documents is None:
|
||||
raise RuntimeError(
|
||||
f"{tool.name} did not return final context documents"
|
||||
)
|
||||
|
||||
self._update_prompt_builder_for_search_tool(
|
||||
prompt_builder, final_context_documents
|
||||
)
|
||||
elif tool.name == ImageGenerationTool._NAME:
|
||||
img_urls = []
|
||||
for response in tool_runner.tool_responses():
|
||||
if response.id == IMAGE_GENERATION_RESPONSE_ID:
|
||||
img_generation_response = cast(
|
||||
list[ImageGenerationResponse], response.response
|
||||
)
|
||||
img_urls = [img.url for img in img_generation_response]
|
||||
|
||||
yield response
|
||||
|
||||
prompt_builder.update_user_prompt(
|
||||
build_image_generation_user_prompt(
|
||||
query=self.question,
|
||||
img_urls=img_urls,
|
||||
# Update message history with LLM response
|
||||
self.message_history.append(
|
||||
PreviousMessage(
|
||||
message=response_content,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
token_count=len(self.llm_tokenizer.encode(response_content)),
|
||||
tool_call=None,
|
||||
files=[],
|
||||
)
|
||||
)
|
||||
else:
|
||||
prompt_builder.update_user_prompt(
|
||||
HumanMessage(
|
||||
content=build_user_message_for_custom_tool_for_non_tool_calling_llm(
|
||||
self.question,
|
||||
tool.name,
|
||||
*tool_runner.tool_responses(),
|
||||
)
|
||||
)
|
||||
)
|
||||
final = tool_runner.tool_final_result()
|
||||
|
||||
yield final
|
||||
|
||||
prompt = prompt_builder.build()
|
||||
|
||||
yield from self._process_llm_stream(prompt=prompt, tools=None)
|
||||
|
||||
@property
|
||||
def processed_streamed_output(self) -> AnswerStream:
|
||||
@@ -495,6 +669,8 @@ class Answer:
|
||||
else self._raw_output_for_non_explicit_tool_calling_llms()
|
||||
)
|
||||
|
||||
self.processing_stream = []
|
||||
|
||||
def _process_stream(
|
||||
stream: Iterator[ToolCallKickoff | ToolResponse | str | StreamStopInfo],
|
||||
) -> AnswerStream:
|
||||
@@ -535,56 +711,70 @@ class Answer:
|
||||
|
||||
yield message
|
||||
else:
|
||||
# assumes all tool responses will come first, then the final answer
|
||||
break
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
|
||||
if not self.skip_gen_ai_answer_generation:
|
||||
process_answer_stream_fn = _get_answer_stream_processor(
|
||||
context_docs=final_context_docs or [],
|
||||
# if doc selection is enabled, then search_results will be None,
|
||||
# so we need to use the final_context_docs
|
||||
doc_id_to_rank_map=map_document_id_order(
|
||||
search_results or final_context_docs or []
|
||||
),
|
||||
answer_style_configs=self.answer_style_config,
|
||||
)
|
||||
stream_stop_info = None
|
||||
new_kickoff = None
|
||||
|
||||
stream_stop_info = None
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
nonlocal new_kickoff
|
||||
|
||||
def _stream() -> Iterator[str]:
|
||||
nonlocal stream_stop_info
|
||||
yield cast(str, message)
|
||||
for item in stream:
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
yield cast(str, item)
|
||||
yield cast(str, message)
|
||||
for item in stream:
|
||||
if isinstance(item, StreamStopInfo):
|
||||
stream_stop_info = item
|
||||
return
|
||||
if isinstance(item, ToolCallKickoff):
|
||||
new_kickoff = item
|
||||
return
|
||||
else:
|
||||
yield cast(str, item)
|
||||
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
yield from process_answer_stream_fn(_stream())
|
||||
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
if stream_stop_info:
|
||||
yield stream_stop_info
|
||||
|
||||
# handle new tool call (continuation of message)
|
||||
if new_kickoff:
|
||||
yield new_kickoff
|
||||
|
||||
processed_stream = []
|
||||
for processed_packet in _process_stream(output_generator):
|
||||
processed_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
if (
|
||||
isinstance(processed_packet, StreamStopInfo)
|
||||
and processed_packet.stop_reason == StreamStopReason.NEW_RESPONSE
|
||||
):
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self.processing_stream = []
|
||||
|
||||
self._processed_stream = processed_stream
|
||||
self.processing_stream.append(processed_packet)
|
||||
yield processed_packet
|
||||
self.current_streamed_output = self.processing_stream
|
||||
self._processed_stream = self.processing_stream
|
||||
|
||||
@property
|
||||
def llm_answer(self) -> str:
|
||||
answer = ""
|
||||
for packet in self.processed_streamed_output:
|
||||
if not self._processed_stream and not self.current_streamed_output:
|
||||
return ""
|
||||
for packet in self.current_streamed_output or self._processed_stream or []:
|
||||
if isinstance(packet, DanswerAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
|
||||
return answer
|
||||
|
||||
@property
|
||||
def citations(self) -> list[CitationInfo]:
|
||||
citations: list[CitationInfo] = []
|
||||
for packet in self.processed_streamed_output:
|
||||
for packet in self.current_streamed_output:
|
||||
if isinstance(packet, CitationInfo):
|
||||
citations.append(packet)
|
||||
|
||||
@@ -599,5 +789,4 @@ class Answer:
|
||||
if not self.is_connected():
|
||||
logger.debug("Answer stream has been cancelled")
|
||||
self._is_cancelled = not self.is_connected()
|
||||
|
||||
return self._is_cancelled
|
||||
|
||||
@@ -33,7 +33,7 @@ class PreviousMessage(BaseModel):
|
||||
token_count: int
|
||||
message_type: MessageType
|
||||
files: list[InMemoryChatFile]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
@classmethod
|
||||
def from_chat_message(
|
||||
@@ -51,14 +51,13 @@ class PreviousMessage(BaseModel):
|
||||
for file in available_files
|
||||
if str(file.file_id) in message_file_ids
|
||||
],
|
||||
tool_calls=[
|
||||
ToolCallFinalResult(
|
||||
tool_name=tool_call.tool_name,
|
||||
tool_args=tool_call.tool_arguments,
|
||||
tool_result=tool_call.tool_result,
|
||||
)
|
||||
for tool_call in chat_message.tool_calls
|
||||
],
|
||||
tool_call=ToolCallFinalResult(
|
||||
tool_name=chat_message.tool_call.tool_name,
|
||||
tool_args=chat_message.tool_call.tool_arguments,
|
||||
tool_result=chat_message.tool_call.tool_result,
|
||||
)
|
||||
if chat_message.tool_call
|
||||
else None,
|
||||
)
|
||||
|
||||
def to_langchain_msg(self) -> BaseMessage:
|
||||
|
||||
@@ -36,7 +36,10 @@ def default_build_system_message(
|
||||
|
||||
|
||||
def default_build_user_message(
|
||||
user_query: str, prompt_config: PromptConfig, files: list[InMemoryChatFile] = []
|
||||
user_query: str,
|
||||
prompt_config: PromptConfig,
|
||||
files: list[InMemoryChatFile] = [],
|
||||
previous_tool_calls: int = 0,
|
||||
) -> HumanMessage:
|
||||
user_prompt = (
|
||||
CHAT_USER_CONTEXT_FREE_PROMPT.format(
|
||||
@@ -45,6 +48,12 @@ def default_build_user_message(
|
||||
if prompt_config.task_prompt
|
||||
else user_query
|
||||
)
|
||||
if previous_tool_calls > 0:
|
||||
user_prompt = (
|
||||
f"You have already generated the above so do not call a tool if not necessary. "
|
||||
f"Remember the query is: `{user_prompt}`"
|
||||
)
|
||||
|
||||
user_prompt = user_prompt.strip()
|
||||
user_msg = HumanMessage(
|
||||
content=build_content_with_imgs(user_prompt, files) if files else user_prompt
|
||||
@@ -113,25 +122,25 @@ class AnswerPromptBuilder:
|
||||
|
||||
final_messages_with_tokens.append(self.user_message_and_token_cnt)
|
||||
|
||||
if tool_call_summary:
|
||||
final_messages_with_tokens.append(
|
||||
(
|
||||
tool_call_summary.tool_call_request,
|
||||
check_message_tokens(
|
||||
tool_call_summary.tool_call_request,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
)
|
||||
final_messages_with_tokens.append(
|
||||
(
|
||||
tool_call_summary.tool_call_result,
|
||||
check_message_tokens(
|
||||
tool_call_summary.tool_call_result,
|
||||
self.llm_tokenizer_encode_func,
|
||||
),
|
||||
)
|
||||
)
|
||||
# if tool_call_summary:
|
||||
# final_messages_with_tokens.append(
|
||||
# (
|
||||
# tool_call_summary.tool_call_request,
|
||||
# check_message_tokens(
|
||||
# tool_call_summary.tool_call_request,
|
||||
# self.llm_tokenizer_encode_func,
|
||||
# ),
|
||||
# )
|
||||
# )
|
||||
# final_messages_with_tokens.append(
|
||||
# (
|
||||
# tool_call_summary.tool_call_result,
|
||||
# check_message_tokens(
|
||||
# tool_call_summary.tool_call_result,
|
||||
# self.llm_tokenizer_encode_func,
|
||||
# ),
|
||||
# )
|
||||
# )
|
||||
|
||||
return drop_messages_history_overflow(
|
||||
final_messages_with_tokens, self.max_tokens
|
||||
|
||||
@@ -29,6 +29,7 @@ from danswer.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.llm.interfaces import ToolChoiceOptions
|
||||
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -98,6 +99,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if message.tool_calls:
|
||||
@@ -124,12 +126,21 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"name": message.name,
|
||||
}
|
||||
elif isinstance(message, ToolMessage):
|
||||
message_dict = {
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"role": "tool",
|
||||
"name": message.name or "",
|
||||
"content": message.content,
|
||||
}
|
||||
if message.id == GRAPHING_RESPONSE_ID:
|
||||
message_dict = {
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"role": "tool",
|
||||
"name": message.name or "",
|
||||
"content": "a graph",
|
||||
}
|
||||
else:
|
||||
message_dict = {
|
||||
"tool_call_id": message.tool_call_id,
|
||||
"role": "tool",
|
||||
"name": message.name or "",
|
||||
"content": "a graph",
|
||||
}
|
||||
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
|
||||
@@ -112,7 +112,7 @@ def translate_danswer_msg_to_langchain(
|
||||
content = build_content_with_imgs(msg.message, files)
|
||||
|
||||
if msg.message_type == MessageType.SYSTEM:
|
||||
raise ValueError("System messages are not currently part of history")
|
||||
return SystemMessage(content=content)
|
||||
if msg.message_type == MessageType.ASSISTANT:
|
||||
return AIMessage(content=content)
|
||||
if msg.message_type == MessageType.USER:
|
||||
@@ -133,6 +133,21 @@ def translate_history_to_basemessages(
|
||||
return history_basemessages, history_token_counts
|
||||
|
||||
|
||||
def _process_csv_file(file: InMemoryChatFile) -> str:
|
||||
import pandas as pd
|
||||
import io
|
||||
|
||||
df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
|
||||
csv_preview = df.head().to_string()
|
||||
|
||||
file_name_section = (
|
||||
f"CSV FILE NAME: {file.filename}\n"
|
||||
if file.filename
|
||||
else "CSV FILE (NO NAME PROVIDED):\n"
|
||||
)
|
||||
return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n"
|
||||
|
||||
|
||||
def _build_content(
|
||||
message: str,
|
||||
files: list[InMemoryChatFile] | None = None,
|
||||
@@ -143,16 +158,26 @@ def _build_content(
|
||||
if files
|
||||
else None
|
||||
)
|
||||
if not text_files:
|
||||
csv_files = (
|
||||
[file for file in files if file.file_type == ChatFileType.CSV]
|
||||
if files
|
||||
else None
|
||||
)
|
||||
|
||||
if not text_files and not csv_files:
|
||||
return message
|
||||
|
||||
final_message_with_files = "FILES:\n\n"
|
||||
for file in text_files:
|
||||
for file in text_files or []:
|
||||
file_content = file.content.decode("utf-8")
|
||||
file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
|
||||
final_message_with_files += (
|
||||
f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
|
||||
)
|
||||
|
||||
for file in csv_files or []:
|
||||
final_message_with_files += _process_csv_file(file)
|
||||
|
||||
final_message_with_files += message
|
||||
|
||||
return final_message_with_files
|
||||
|
||||
@@ -39,6 +39,46 @@ CHAT_USER_CONTEXT_FREE_PROMPT = f"""
|
||||
""".strip()
|
||||
|
||||
|
||||
GRAPHING_QUERY_REPHRASE_GRAPH = f"""
|
||||
Given the following conversation and a follow-up input, generate Python code using matplotlib to create the requested graph.
|
||||
IMPORTANT: The code should be complete and executable, including data generation if necessary.
|
||||
Focus on creating a clear and informative visualization based on the user's request.
|
||||
Guidelines:
|
||||
- Import matplotlib.pyplot as plt at the beginning of the code.
|
||||
- Generate sample data if specific data is not provided in the conversation.
|
||||
- Use appropriate graph types (line, bar, scatter, pie) based on the nature of the data and request.
|
||||
- Include proper labeling (title, x-axis, y-axis, legend) in the graph.
|
||||
- Use plt.figure() to create the figure and assign it to a variable named 'fig'.
|
||||
- Do not include plt.show() at the end of the code.
|
||||
{GENERAL_SEP_PAT}
|
||||
Chat History:
|
||||
{{chat_history}}
|
||||
{GENERAL_SEP_PAT}
|
||||
Follow Up Input: {{question}}
|
||||
Python Code for Graph (Respond with only the Python code to generate the graph):
|
||||
```python
|
||||
# Your code here
|
||||
```
|
||||
""".strip()
|
||||
|
||||
|
||||
GRAPHING_GET_FILE_NAME_PROMPT = f"""
|
||||
Given the following conversation, a follow-up input, and a list of available CSV files,
|
||||
provide the name of the CSV file to analyze.
|
||||
|
||||
{GENERAL_SEP_PAT}
|
||||
Chat History:
|
||||
{{chat_history}}
|
||||
{GENERAL_SEP_PAT}
|
||||
Follow Up Input: {{question}}
|
||||
{GENERAL_SEP_PAT}
|
||||
Available CSV Files:
|
||||
{{file_list}}
|
||||
{GENERAL_SEP_PAT}
|
||||
CSV File Name to Analyze:
|
||||
"""
|
||||
|
||||
|
||||
# Design considerations for the below:
|
||||
# - In case of uncertainty, favor yes search so place the "yes" sections near the start of the
|
||||
# prompt and after the no section as well to deemphasize the no section
|
||||
|
||||
@@ -122,6 +122,25 @@ def upload_file(
|
||||
return {"file_id": file_id}
|
||||
|
||||
|
||||
@admin_router.post("/upload-csv")
|
||||
def upload_csv(
|
||||
file: UploadFile,
|
||||
db_session: Session = Depends(get_session),
|
||||
_: User | None = Depends(current_user),
|
||||
) -> dict[str, str]:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_type = ChatFileType.CSV
|
||||
file_id = str(uuid.uuid4())
|
||||
file_store.save_file(
|
||||
file_name=file_id,
|
||||
content=file.file,
|
||||
display_name=file.filename,
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=file.content_type or file_type.value,
|
||||
)
|
||||
return {"file_id": file_id}
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
|
||||
|
||||
|
||||
@@ -281,14 +281,17 @@ async def is_disconnected(request: Request) -> Callable[[], bool]:
|
||||
def is_disconnected_sync() -> bool:
|
||||
future = asyncio.run_coroutine_threadsafe(request.is_disconnected(), main_loop)
|
||||
try:
|
||||
return not future.result(timeout=0.01)
|
||||
result = not future.result(timeout=0.01)
|
||||
return result
|
||||
except asyncio.TimeoutError:
|
||||
logger.error("Asyncio timed out")
|
||||
logger.error("Asyncio timed out while checking client connection")
|
||||
return True
|
||||
except asyncio.CancelledError:
|
||||
return True
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
logger.critical(
|
||||
f"An unexpected error occured with the disconnect check coroutine: {error_msg}"
|
||||
f"An unexpected error occurred with the disconnect check coroutine: {error_msg}"
|
||||
)
|
||||
return True
|
||||
|
||||
@@ -517,7 +520,6 @@ def upload_files_for_chat(
|
||||
image_content_types = {"image/jpeg", "image/png", "image/webp"}
|
||||
text_content_types = {
|
||||
"text/plain",
|
||||
"text/csv",
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"text/x-config",
|
||||
@@ -527,6 +529,9 @@ def upload_files_for_chat(
|
||||
"text/xml",
|
||||
"application/x-yaml",
|
||||
}
|
||||
csv_content_types = {
|
||||
"text/csv",
|
||||
}
|
||||
document_content_types = {
|
||||
"application/pdf",
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document",
|
||||
@@ -536,8 +541,10 @@ def upload_files_for_chat(
|
||||
"application/epub+zip",
|
||||
}
|
||||
|
||||
allowed_content_types = image_content_types.union(text_content_types).union(
|
||||
document_content_types
|
||||
allowed_content_types = (
|
||||
image_content_types.union(text_content_types)
|
||||
.union(document_content_types)
|
||||
.union(csv_content_types)
|
||||
)
|
||||
|
||||
for file in files:
|
||||
@@ -545,8 +552,12 @@ def upload_files_for_chat(
|
||||
if file.content_type in image_content_types:
|
||||
error_detail = "Unsupported image file type. Supported image types include .jpg, .jpeg, .png, .webp."
|
||||
elif file.content_type in text_content_types:
|
||||
error_detail = "Unsupported text file type. Supported text types include .txt, .csv, .md, .mdx, .conf, "
|
||||
error_detail = "Unsupported text file type. Supported text types include .txt, .md, .mdx, .conf, "
|
||||
".log, .tsv."
|
||||
elif file.content_type in csv_content_types:
|
||||
error_detail = (
|
||||
"Unsupported csv file type. Supported CSV types include .csv "
|
||||
)
|
||||
else:
|
||||
error_detail = (
|
||||
"Unsupported document file type. Supported document types include .pdf, .docx, .pptx, .xlsx, "
|
||||
@@ -572,6 +583,8 @@ def upload_files_for_chat(
|
||||
file_type = ChatFileType.IMAGE
|
||||
elif file.content_type in document_content_types:
|
||||
file_type = ChatFileType.DOC
|
||||
elif file.content_type in csv_content_types:
|
||||
file_type = ChatFileType.CSV
|
||||
else:
|
||||
file_type = ChatFileType.PLAIN_TEXT
|
||||
|
||||
@@ -584,6 +597,7 @@ def upload_files_for_chat(
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type=file.content_type or file_type.value,
|
||||
)
|
||||
print(f"FILE TYPE IS {file_type}")
|
||||
|
||||
# if the file is a doc, extract text and store that so we don't need
|
||||
# to re-extract it every time we send a message
|
||||
|
||||
@@ -178,7 +178,7 @@ class ChatMessageDetail(BaseModel):
|
||||
chat_session_id: int | None = None
|
||||
citations: dict[int, int] | None = None
|
||||
files: list[FileDescriptor]
|
||||
tool_calls: list[ToolCallFinalResult]
|
||||
tool_call: ToolCallFinalResult | None
|
||||
|
||||
def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]: # type: ignore
|
||||
initial_dict = super().model_dump(mode="json", *args, **kwargs) # type: ignore
|
||||
|
||||
@@ -1,19 +1,22 @@
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
|
||||
class DateTimeEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that converts datetime objects to ISO format strings."""
|
||||
class DateTimeAndBytesEncoder(json.JSONEncoder):
|
||||
"""Custom JSON encoder that converts datetime objects to ISO format strings and bytes to base64."""
|
||||
|
||||
def default(self, obj: Any) -> Any:
|
||||
if isinstance(obj, datetime):
|
||||
return obj.isoformat()
|
||||
elif isinstance(obj, bytes):
|
||||
return base64.b64encode(obj).decode("utf-8")
|
||||
return super().default(obj)
|
||||
|
||||
|
||||
def get_json_line(
|
||||
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeEncoder
|
||||
json_dict: dict[str, Any], encoder: type[json.JSONEncoder] = DateTimeAndBytesEncoder
|
||||
) -> str:
|
||||
"""
|
||||
Convert a dictionary to a JSON string with datetime handling, and add a newline.
|
||||
|
||||
161
backend/danswer/tools/analysis/analysis_tool.py
Normal file
161
backend/danswer/tools/analysis/analysis_tool.py
Normal file
@@ -0,0 +1,161 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
CSV_ANALYSIS_RESPONSE_ID = "csv_analysis_response"
|
||||
|
||||
YES_ANALYSIS = "Yes Analysis"
|
||||
SKIP_ANALYSIS = "Skip Analysis"
|
||||
|
||||
ANALYSIS_TEMPLATE = f"""
|
||||
Given the conversation history and a follow up query,
|
||||
determine if the system should analyze a CSV file to better answer the latest user input.
|
||||
Your default response is {SKIP_ANALYSIS}.
|
||||
Respond "{YES_ANALYSIS}" if:
|
||||
- The user is asking about the structure or content of a CSV file.
|
||||
- The user explicitly requests information about a data file.
|
||||
Conversation History:
|
||||
{{chat_history}}
|
||||
If you are at all unsure, respond with {SKIP_ANALYSIS}.
|
||||
Respond with EXACTLY and ONLY "{YES_ANALYSIS}" or "{SKIP_ANALYSIS}"
|
||||
Follow Up Input:
|
||||
{{final_query}}
|
||||
""".strip()
|
||||
|
||||
system_message = """
|
||||
You analyze CSV files by examining their structure and content.
|
||||
Your analysis should include:
|
||||
1. Number of columns and their names
|
||||
2. Data types of each column
|
||||
3. First few rows of data
|
||||
4. Basic statistics (if applicable)
|
||||
Provide a concise summary of the file's content and structure.
|
||||
"""
|
||||
|
||||
|
||||
class CSVAnalysisTool(Tool):
|
||||
_NAME = "analyze_csv"
|
||||
_DISPLAY_NAME = "CSV Analysis Tool"
|
||||
_DESCRIPTION = system_message
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTION
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._DISPLAY_NAME
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"file_path": {
|
||||
"type": "string",
|
||||
"description": "Path to the CSV file to analyze",
|
||||
},
|
||||
},
|
||||
"required": ["file_path"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def check_if_needs_analysis(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
) -> bool:
|
||||
history_str = "\n".join([f"{m.message}" for m in history])
|
||||
prompt = ANALYSIS_TEMPLATE.format(
|
||||
chat_history=history_str,
|
||||
final_query=query,
|
||||
)
|
||||
use_analysis_output = llm.invoke(prompt)
|
||||
|
||||
logger.debug(f"Evaluated if should use CSV analysis: {use_analysis_output}")
|
||||
content = use_analysis_output.content
|
||||
print(content)
|
||||
|
||||
return YES_ANALYSIS.lower() in str(content).lower()
|
||||
|
||||
def get_args_for_non_tool_calling_llm(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
force_run: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
if not force_run and not self.check_if_needs_analysis(query, history, llm):
|
||||
return None
|
||||
|
||||
return {
|
||||
"prompt": query,
|
||||
}
|
||||
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
graph_response = next(arg for arg in args if arg.id == CSV_ANALYSIS_RESPONSE_ID)
|
||||
return json.dumps(graph_response.response.dict())
|
||||
|
||||
def run(
|
||||
self, llm: LLM | None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
if llm is not None:
|
||||
logger.warning("LLM passed to CSVAnalysisTool.run() but not used")
|
||||
|
||||
file_path = kwargs["file_path"]
|
||||
|
||||
try:
|
||||
# Read the first few rows of the CSV file
|
||||
df = pd.read_csv(file_path, nrows=5)
|
||||
|
||||
# Analyze the structure and content
|
||||
analysis = {
|
||||
"num_columns": len(df.columns),
|
||||
"column_names": df.columns.tolist(),
|
||||
"data_types": df.dtypes.astype(str).tolist(),
|
||||
"first_rows": df.to_dict(orient="records"),
|
||||
"basic_stats": df.describe().to_dict(),
|
||||
}
|
||||
|
||||
# Convert the analysis to JSON
|
||||
analysis_json = json.dumps(analysis, indent=2)
|
||||
|
||||
yield ToolResponse(id=CSV_ANALYSIS_RESPONSE_ID, response=analysis_json)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error analyzing CSV file: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
yield ToolResponse(id="ERROR", response=error_msg)
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
try:
|
||||
analysis_response = next(
|
||||
arg for arg in args if arg.id == CSV_ANALYSIS_RESPONSE_ID
|
||||
)
|
||||
return json.loads(analysis_response.response)
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Unexpected error in final_result: {str(e)}"}
|
||||
@@ -9,6 +9,8 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from danswer.db.models import Persona
|
||||
from danswer.db.models import Tool as ToolDBModel
|
||||
from danswer.tools.analysis.analysis_tool import CSVAnalysisTool
|
||||
from danswer.tools.graphing.graphing_tool import GraphingTool
|
||||
from danswer.tools.images.image_generation_tool import ImageGenerationTool
|
||||
from danswer.tools.internet_search.internet_search_tool import InternetSearchTool
|
||||
from danswer.tools.search.search_tool import SearchTool
|
||||
@@ -41,6 +43,18 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
|
||||
in_code_tool_id=ImageGenerationTool.__name__,
|
||||
display_name=ImageGenerationTool._DISPLAY_NAME,
|
||||
),
|
||||
InCodeToolInfo(
|
||||
cls=GraphingTool,
|
||||
description=("The graphing Tool allows the assistant to make graphs. "),
|
||||
in_code_tool_id=GraphingTool.__name__,
|
||||
display_name=GraphingTool._DISPLAY_NAME,
|
||||
),
|
||||
InCodeToolInfo(
|
||||
cls=CSVAnalysisTool,
|
||||
description=("The CSV Tool allows the assistant to make graphs. "),
|
||||
in_code_tool_id=CSVAnalysisTool.__name__,
|
||||
display_name=CSVAnalysisTool._DISPLAY_NAME,
|
||||
),
|
||||
# don't show the InternetSearchTool as an option if BING_API_KEY is not available
|
||||
*(
|
||||
[
|
||||
|
||||
@@ -144,7 +144,9 @@ class CustomTool(Tool):
|
||||
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, llm: None | LLM = None, **kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
request_body = kwargs.get(REQUEST_BODY)
|
||||
|
||||
path_params = {}
|
||||
|
||||
377
backend/danswer/tools/graphing/graphing_tool.py
Normal file
377
backend/danswer/tools/graphing/graphing_tool.py
Normal file
@@ -0,0 +1,377 @@
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import matplotlib
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
|
||||
from danswer.db.engine import get_session_context_manager
|
||||
from danswer.db.models import FileOrigin
|
||||
from danswer.dynamic_configs.interface import JSON_ro
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.file_store.models import ChatFileType
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.utils import _process_csv_file
|
||||
from danswer.prompts.chat_prompts import (
|
||||
GRAPHING_GET_FILE_NAME_PROMPT,
|
||||
) # You'll need to create this
|
||||
from danswer.tools.graphing.models import GRAPHING_RESPONSE_ID
|
||||
from danswer.tools.graphing.models import GraphingError
|
||||
from danswer.tools.graphing.models import GraphingResponse
|
||||
from danswer.tools.graphing.models import GraphType
|
||||
from danswer.tools.tool import Tool
|
||||
from danswer.tools.tool import ToolResponse
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
matplotlib.use("Agg") # Use non-interactive backend
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
FINAL_GRAPH_IMAGE = "final_graph_image"
|
||||
|
||||
YES_GRAPHING = "Yes Graphing"
|
||||
SKIP_GRAPHING = "Skip Graphing"
|
||||
|
||||
GRAPHING_TEMPLATE = f"""
|
||||
Given the conversation history and a follow up query,
|
||||
determine if the system should create a graph to better answer the latest user input.
|
||||
Your default response is {SKIP_GRAPHING}.
|
||||
Respond "{YES_GRAPHING}" if:
|
||||
- The user is asking for information that would be better represented in a graph.
|
||||
- The user explicitly requests a graph or chart.
|
||||
Conversation History:
|
||||
{{chat_history}}
|
||||
If you are at all unsure, respond with {SKIP_GRAPHING}.
|
||||
Respond with EXACTLY and ONLY "{YES_GRAPHING}" or "{SKIP_GRAPHING}"
|
||||
Follow Up Input:
|
||||
{{final_query}}
|
||||
""".strip()
|
||||
|
||||
|
||||
system_message = """
|
||||
You create Python code for graphs using matplotlib. Your code should:
|
||||
Import libraries: matplotlib.pyplot as plt, numpy as np
|
||||
Define data (create sample data if needed)
|
||||
Create the plot:
|
||||
Use fig, ax = plt.subplots(figsize=(10, 6))
|
||||
Use ax methods for plotting (e.g., ax.plot(), ax.bar())
|
||||
Set labels, title, legend using ax methods
|
||||
Not include plt.show()
|
||||
Key points:
|
||||
Use 'ax' for all plotting functions
|
||||
Use 'plt' only for figure creation and any necessary global settings
|
||||
Provide raw Python code without formatting
|
||||
"""
|
||||
|
||||
TabError
|
||||
|
||||
|
||||
class GraphingTool(Tool):
|
||||
_NAME = "create_graph"
|
||||
_DISPLAY_NAME = "Graphing Tool"
|
||||
_DESCRIPTION = system_message
|
||||
|
||||
def __init__(self, output_dir: str = "generated_graphs"):
|
||||
self.output_dir = output_dir
|
||||
try:
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
except Exception as e:
|
||||
logger.error(f"Error creating output directory: {e}")
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._NAME
|
||||
|
||||
@property
|
||||
def description(self) -> str:
|
||||
return self._DESCRIPTION
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
return self._DISPLAY_NAME
|
||||
|
||||
# def tool_definition(self) -> dict:
|
||||
# return {
|
||||
# "type": "function",
|
||||
# "function": {
|
||||
# "name": self.name,
|
||||
# "description": self.description,
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "code": {
|
||||
# "type": "string",
|
||||
# "description": "Python code to generate the graph using matplotlib and seaborn",
|
||||
# },
|
||||
# },
|
||||
# "required": ["code"],
|
||||
# },
|
||||
# },
|
||||
# }
|
||||
|
||||
def tool_definition(self) -> dict:
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"filename": {
|
||||
"type": "string",
|
||||
"description": "The name of the file to analyze for graphing purposes",
|
||||
},
|
||||
},
|
||||
"required": ["filename"],
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
def check_if_needs_graphing(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
) -> bool:
|
||||
# return True
|
||||
history_str = "\n".join([f"{m.message}" for m in history])
|
||||
prompt = GRAPHING_TEMPLATE.format(
|
||||
chat_history=history_str,
|
||||
final_query=query,
|
||||
)
|
||||
use_graphing_output = llm.invoke(prompt)
|
||||
print(use_graphing_output)
|
||||
|
||||
logger.debug(f"Evaluated if should use graphing: {use_graphing_output}")
|
||||
content = use_graphing_output.content
|
||||
|
||||
return YES_GRAPHING.lower() in str(content).lower()
|
||||
|
||||
def get_args_for_non_tool_calling_llm(
|
||||
self,
|
||||
query: str,
|
||||
history: list[PreviousMessage],
|
||||
llm: LLM,
|
||||
force_run: bool = False,
|
||||
) -> dict[str, Any] | None:
|
||||
if not force_run and not self.check_if_needs_graphing(query, history, llm):
|
||||
return None
|
||||
|
||||
file_names_and_descriptions = []
|
||||
for message in history:
|
||||
for file in message.files:
|
||||
if file.file_type == ChatFileType.CSV:
|
||||
file_name = file.filename
|
||||
description = _process_csv_file(file)
|
||||
file_names_and_descriptions.append(f"{file_name}: {description}")
|
||||
|
||||
# Use the GRAPHING_GET_FILE_NAME_PROMPT to get the file name
|
||||
file_list = "\n".join(file_names_and_descriptions)
|
||||
prompt = GRAPHING_GET_FILE_NAME_PROMPT.format(
|
||||
chat_history="\n".join([f"{m.message}" for m in history]),
|
||||
question=query,
|
||||
file_list=file_list,
|
||||
)
|
||||
|
||||
file_name_response = llm.invoke(prompt)
|
||||
file_name = (
|
||||
file_name_response.content
|
||||
if isinstance(file_name_response.content, str)
|
||||
else ""
|
||||
)
|
||||
file_name = file_name.strip()
|
||||
|
||||
# Validate that the returned file name is in our list of available files
|
||||
available_files = [
|
||||
name.split(":")[0].strip() for name in file_names_and_descriptions
|
||||
]
|
||||
if file_name not in available_files:
|
||||
logger.warning(f"LLM returned invalid file name: {file_name}")
|
||||
file_name = available_files[0] if available_files else None
|
||||
|
||||
return {
|
||||
"filename": file_name,
|
||||
}
|
||||
|
||||
def build_tool_message_content(
|
||||
self, *args: ToolResponse
|
||||
) -> str | list[str | dict[str, Any]]:
|
||||
graph_response = next(arg for arg in args if arg.id == GRAPHING_RESPONSE_ID)
|
||||
return json.dumps(graph_response.response.dict())
|
||||
|
||||
@staticmethod
|
||||
def preprocess_code(code: str) -> str:
|
||||
# Extract code between triple backticks
|
||||
code_match = re.search(r"```python\n(.*?)```", code, re.DOTALL)
|
||||
if code_match:
|
||||
return code_match.group(1).strip()
|
||||
# If no code block is found, remove any explanatory text and return the rest
|
||||
return re.sub(r"^.*?import", "import", code, flags=re.DOTALL).strip()
|
||||
|
||||
@staticmethod
|
||||
def is_line_plot(ax: plt.Axes) -> bool:
|
||||
return len(ax.lines) > 0
|
||||
|
||||
@staticmethod
|
||||
def is_bar_plot(ax: plt.Axes) -> bool:
|
||||
return len(ax.patches) > 0 and isinstance(ax.patches[0], plt.Rectangle)
|
||||
|
||||
@staticmethod
|
||||
def extract_line_plot_data(ax: plt.Axes) -> dict[str, Any]:
|
||||
data = []
|
||||
for line in ax.lines:
|
||||
line_data = {
|
||||
"x": line.get_xdata(), # type: ignore
|
||||
"y": line.get_ydata(), # type: ignore
|
||||
"label": line.get_label(),
|
||||
"color": line.get_color(),
|
||||
}
|
||||
data.append(line_data)
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"title": ax.get_title(),
|
||||
"xlabel": ax.get_xlabel(),
|
||||
"ylabel": ax.get_ylabel(),
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def extract_bar_plot_data(ax: plt.Axes) -> dict[str, Any]:
|
||||
data = []
|
||||
for patch in ax.patches:
|
||||
bar_data = {
|
||||
"x": float(patch.get_bbox().x0 + patch.get_bbox().width / 2), # type: ignore
|
||||
"y": float(patch.get_bbox().height), # type: ignore
|
||||
"width": float(patch.get_bbox().width), # type: ignore
|
||||
"color": patch.get_facecolor(),
|
||||
}
|
||||
data.append(bar_data)
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"title": ax.get_title(),
|
||||
"xlabel": ax.get_xlabel(),
|
||||
"ylabel": ax.get_ylabel(),
|
||||
"xticks": ax.get_xticks(),
|
||||
"xticklabels": [label.get_text() for label in ax.get_xticklabels()],
|
||||
}
|
||||
|
||||
def run(
|
||||
self, llm: LLM | None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
if llm is None:
|
||||
raise ValueError("This tool requires an LLM to run")
|
||||
|
||||
file_content = kwargs["filename"]
|
||||
file_content = file_content.decode("utf-8") # type: ignore
|
||||
csv_file = StringIO(file_content)
|
||||
df = pd.read_csv(csv_file)
|
||||
|
||||
# Generate a summary of the CSV data
|
||||
data_summary = df.describe().to_string()
|
||||
columns_info = df.dtypes.to_string()
|
||||
|
||||
# Create a prompt for the LLM to generate the plotting code
|
||||
code_generation_prompt = f"""
|
||||
{system_message}
|
||||
|
||||
Given the following CSV data summary and user query, create Python code to generate an appropriate graph:
|
||||
|
||||
Data Summary:
|
||||
{data_summary}
|
||||
|
||||
Columns:
|
||||
{columns_info}
|
||||
|
||||
User Query:
|
||||
"Graph the data"
|
||||
|
||||
Generate the Python code to create the graph:
|
||||
"""
|
||||
|
||||
code_response = llm.invoke(code_generation_prompt)
|
||||
code = self.preprocess_code(cast(str, code_response.content))
|
||||
|
||||
# Continue with the existing code to execute and process the graph
|
||||
locals_dict = {"plt": plt, "matplotlib": matplotlib, "np": np, "df": df}
|
||||
file_id = None
|
||||
|
||||
try:
|
||||
exec(code, globals(), locals_dict)
|
||||
|
||||
fig = locals_dict.get("fig")
|
||||
|
||||
if fig is None:
|
||||
raise ValueError("The provided code did not create a 'fig' variable")
|
||||
|
||||
ax = fig.gca() # type: ignore
|
||||
|
||||
plot_data = None
|
||||
plot_type: GraphType | None = None
|
||||
if self.is_line_plot(ax):
|
||||
plot_data = self.extract_line_plot_data(ax)
|
||||
plot_type = GraphType.LINE_GRAPH
|
||||
elif self.is_bar_plot(ax):
|
||||
plot_data = self.extract_bar_plot_data(ax)
|
||||
plot_type = GraphType.BAR_CHART
|
||||
|
||||
if plot_data:
|
||||
plot_data_file = os.path.join(self.output_dir, "plot_data.json")
|
||||
with open(plot_data_file, "w") as f:
|
||||
json.dump(plot_data, f)
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_id = str(uuid.uuid4())
|
||||
|
||||
json_content = json.dumps(plot_data)
|
||||
json_bytes = json_content.encode("utf-8")
|
||||
|
||||
file_store.save_file(
|
||||
file_name=file_id,
|
||||
content=BytesIO(json_bytes),
|
||||
display_name="temporary",
|
||||
file_origin=FileOrigin.CHAT_UPLOAD,
|
||||
file_type="json",
|
||||
)
|
||||
|
||||
buf = BytesIO()
|
||||
fig.savefig(buf, format="png", bbox_inches="tight") # type: ignore
|
||||
with open("aaa garp.png", "wb") as f:
|
||||
f.write(buf.getvalue())
|
||||
|
||||
yield ToolResponse(
|
||||
id=GRAPHING_RESPONSE_ID,
|
||||
response=GraphingResponse(
|
||||
file_id=str(file_id),
|
||||
graph_type=plot_type.value # type: ignore
|
||||
if plot_type
|
||||
else None, # Use .value to get the string
|
||||
plot_data=plot_data, # Pass the dictionary directly, not as a JSON string
|
||||
),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error generating graph: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
yield ToolResponse(id="ERROR", response=GraphingError(error=error_msg))
|
||||
|
||||
def final_result(self, *args: ToolResponse) -> JSON_ro:
|
||||
try:
|
||||
graph_response = next(arg for arg in args if arg.id == GRAPHING_RESPONSE_ID)
|
||||
return graph_response.response.dict()
|
||||
|
||||
except Exception as e:
|
||||
return {"error": f"Unexpected error in final_result: {str(e)}"}
|
||||
27
backend/danswer/tools/graphing/models.py
Normal file
27
backend/danswer/tools/graphing/models.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class GraphGenerationDisplay(BaseModel):
|
||||
file_id: str
|
||||
line_graph: bool
|
||||
|
||||
|
||||
class GraphType(str, Enum):
|
||||
BAR_CHART = "bar_chart"
|
||||
LINE_GRAPH = "line_graph"
|
||||
|
||||
|
||||
class GraphingResponse(BaseModel):
|
||||
file_id: str
|
||||
plot_data: dict[str, Any] | None
|
||||
graph_type: GraphType
|
||||
|
||||
|
||||
class GraphingError(BaseModel):
|
||||
error: str
|
||||
|
||||
|
||||
GRAPHING_RESPONSE_ID = "graphing_response"
|
||||
8
backend/danswer/tools/graphing/prompt.py
Normal file
8
backend/danswer/tools/graphing/prompt.py
Normal file
@@ -0,0 +1,8 @@
|
||||
NON_TOOL_CALLING_PROMPT = """
|
||||
You have just created the attached images in response to the following query: "{{query}}".
|
||||
Can you please summarize them in a sentence or two?
|
||||
"""
|
||||
|
||||
TOOL_CALLING_PROMPT = """
|
||||
Can you please summarize the two images you generate in a sentence or two?
|
||||
"""
|
||||
@@ -223,7 +223,12 @@ class ImageGenerationTool(Tool):
|
||||
"An error occurred during image generation. Please try again later."
|
||||
)
|
||||
|
||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, llm: LLM | None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
if llm is not None:
|
||||
logger.warning("LLM passed to ImageGenerationTool.run() but not used")
|
||||
|
||||
prompt = cast(str, kwargs["prompt"])
|
||||
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from danswer.llm.utils import build_content_with_imgs
|
||||
|
||||
|
||||
IMG_GENERATION_SUMMARY_PROMPT = """
|
||||
You have just created the attached images in response to the following query: "{query}".
|
||||
You have just created the most recent attached images in response to the following query: "{query}".
|
||||
|
||||
Can you please summarize them in a sentence or two? Do NOT include image urls or bulleted lists.
|
||||
"""
|
||||
|
||||
@@ -209,7 +209,12 @@ class InternetSearchTool(Tool):
|
||||
],
|
||||
)
|
||||
|
||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, llm: LLM | None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
if llm is not None:
|
||||
logger.warning("LLM passed to InternetSearchTool.run() but not used")
|
||||
|
||||
query = cast(str, kwargs["internet_search_query"])
|
||||
|
||||
results = self._perform_search(query)
|
||||
|
||||
@@ -262,7 +262,12 @@ class SearchTool(Tool):
|
||||
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||
|
||||
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, llm: LLM | None = None, **kwargs: str
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
if llm is not None:
|
||||
logger.warning("LLM passed to ImageGenerationTool.run() but not used")
|
||||
|
||||
query = cast(str, kwargs["query"])
|
||||
|
||||
if self.selected_sections:
|
||||
|
||||
@@ -51,7 +51,9 @@ class Tool(abc.ABC):
|
||||
"""Actual execution of the tool"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def run(self, **kwargs: Any) -> Generator[ToolResponse, None, None]:
|
||||
def run(
|
||||
self, llm: LLM | None = None, **kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import base64
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
@@ -12,9 +13,10 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_paralle
|
||||
|
||||
|
||||
class ToolRunner:
|
||||
def __init__(self, tool: Tool, args: dict[str, Any]):
|
||||
def __init__(self, tool: Tool, args: dict[str, Any], llm: LLM | None = None):
|
||||
self.tool = tool
|
||||
self.args = args
|
||||
self._llm = llm
|
||||
|
||||
self._tool_responses: list[ToolResponse] | None = None
|
||||
|
||||
@@ -22,12 +24,25 @@ class ToolRunner:
|
||||
return ToolCallKickoff(tool_name=self.tool.name, tool_args=self.args)
|
||||
|
||||
def tool_responses(self) -> Generator[ToolResponse, None, None]:
|
||||
print("i am in the tool responses function")
|
||||
if self._tool_responses is not None:
|
||||
print("prev")
|
||||
print(self._tool_responses)
|
||||
|
||||
yield from self._tool_responses
|
||||
return
|
||||
|
||||
tool_responses: list[ToolResponse] = []
|
||||
for tool_response in self.tool.run(**self.args):
|
||||
print("runinnig the tool")
|
||||
print(self.tool.name)
|
||||
|
||||
for tool_response in self.tool.run(llm=self._llm, **self.args):
|
||||
if isinstance(tool_response.response, bytes):
|
||||
tool_response.response = base64.b64encode(
|
||||
tool_response.response
|
||||
).decode("utf-8")
|
||||
|
||||
print("tool response")
|
||||
yield tool_response
|
||||
tool_responses.append(tool_response)
|
||||
|
||||
@@ -52,4 +67,5 @@ def check_which_tools_should_run_for_non_tool_calling_llm(
|
||||
(tool.get_args_for_non_tool_calling_llm, (query, history, llm))
|
||||
for tool in tools
|
||||
]
|
||||
|
||||
return run_functions_tuples_in_parallel(tool_args_list)
|
||||
|
||||
@@ -43,6 +43,7 @@ def select_single_tool_for_non_tool_calling_llm(
|
||||
llm: LLM,
|
||||
) -> tuple[Tool, dict[str, Any]] | None:
|
||||
if len(tools_and_args) == 1:
|
||||
logger.info("Only one tool available, returning it directly")
|
||||
return tools_and_args[0]
|
||||
|
||||
tool_list_str = "\n".join(
|
||||
@@ -58,21 +59,26 @@ def select_single_tool_for_non_tool_calling_llm(
|
||||
tool_list=tool_list_str, chat_history=history_str, query=query
|
||||
)
|
||||
output = message_to_string(llm.invoke(prompt))
|
||||
logger.info(f"LLM output for tool selection: {output}")
|
||||
try:
|
||||
# First try to match the number
|
||||
number_match = re.search(r"\d+", output)
|
||||
if number_match:
|
||||
tool_ind = int(number_match.group())
|
||||
logger.info(f"Selected tool by index: {tool_ind}")
|
||||
return tools_and_args[tool_ind]
|
||||
|
||||
# If that fails, try to match the tool name
|
||||
for tool, args in tools_and_args:
|
||||
if tool.name.lower() in output.lower():
|
||||
logger.info(f"Selected tool by name: {tool.name}")
|
||||
return tool, args
|
||||
|
||||
# If that fails, return the first tool
|
||||
logger.warning("Failed to match tool by index or name, returning first tool")
|
||||
return tools_and_args[0]
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to select single tool for non-tool-calling LLM: {output}")
|
||||
logger.exception(e)
|
||||
return None
|
||||
|
||||
1
backend/output/plot_data.json
Normal file
1
backend/output/plot_data.json
Normal file
@@ -0,0 +1 @@
|
||||
{"data": [{"x": 0.0, "y": 91.0, "width": 0.8, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 1.0, "y": 46.0, "width": 0.7999999999999999, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 2.0, "y": 26.41338, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 3.0, "y": 1.0, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 4.0, "y": 23.5, "width": 0.8000000000000003, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 5.0, "y": 46.0, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 6.0, "y": 68.5, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}, {"x": 7.0, "y": 91.0, "width": 0.7999999999999998, "color": [0.12156862745098039, 0.4666666666666667, 0.7058823529411765, 1.0]}], "title": "Summary Statistics of Id Column", "xlabel": "Statistic", "ylabel": "Value", "xticks": [0, 1, 2, 3, 4, 5, 6, 7], "xticklabels": ["count", "mean", "std", "min", "25%", "50%", "75%", "max"]}
|
||||
@@ -21,6 +21,8 @@ CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
|
||||
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
|
||||
INTENT_MODEL_TAG = "v1.0.3"
|
||||
|
||||
# Tool call configs
|
||||
MAX_TOOL_CALLS = 3
|
||||
|
||||
# Bi-Encoder, other details
|
||||
DOC_EMBEDDING_CONTEXT_SIZE = 512
|
||||
|
||||
BIN
backend/zample garp.png
Normal file
BIN
backend/zample garp.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 40 KiB |
BIN
backend/zample garph.png
Normal file
BIN
backend/zample garph.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 40 KiB |
BIN
backend/zzagraph.png
Normal file
BIN
backend/zzagraph.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 40 KiB |
@@ -211,6 +211,7 @@ services:
|
||||
- NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
|
||||
- NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_NEGATIVE_PREDEFINED_FEEDBACK_OPTIONS:-}
|
||||
- NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT:-}
|
||||
- NEXT_PUBLIC_DISABLE_CSV_DISPLAY=${NEXT_PUBLIC_DISABLE_CSV_DISPLAY:-}
|
||||
- NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN=${NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN:-}
|
||||
|
||||
# Enterprise Edition only
|
||||
@@ -292,7 +293,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -302,7 +302,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432:5432"
|
||||
- "5433:5432"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -154,7 +154,7 @@ services:
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASSWORD:-password}
|
||||
ports:
|
||||
- "5432"
|
||||
- "5433"
|
||||
volumes:
|
||||
- db_volume:/var/lib/postgresql/data
|
||||
|
||||
|
||||
@@ -62,6 +62,9 @@ ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL
|
||||
ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL}
|
||||
|
||||
|
||||
ARG NEXT_PUBLIC_DISABLE_CSV_DISPLAY
|
||||
ENV NEXT_PUBLIC_DISABLE_CSV_DISPLAY=${NEXT_PUBLIC_DISABLE_CSV_DISPLAY}
|
||||
|
||||
RUN npx next build
|
||||
|
||||
# Step 2. Production image, copy all the files and run next
|
||||
@@ -122,6 +125,9 @@ ENV NEXT_PUBLIC_DISABLE_LOGOUT=${NEXT_PUBLIC_DISABLE_LOGOUT}
|
||||
ARG NEXT_PUBLIC_CUSTOM_REFRESH_URL
|
||||
ENV NEXT_PUBLIC_CUSTOM_REFRESH_URL=${NEXT_PUBLIC_CUSTOM_REFRESH_URL}
|
||||
|
||||
ARG NEXT_PUBLIC_DISABLE_CSV_DISPLAY
|
||||
ENV NEXT_PUBLIC_DISABLE_CSV_DISPLAY=${NEXT_PUBLIC_DISABLE_CSV_DISPLAY}
|
||||
|
||||
# Note: Don't expose ports here, Compose will handle that for us if necessary.
|
||||
# If you want to run this without compose, specify the ports to
|
||||
# expose via cli
|
||||
|
||||
17
web/components.json
Normal file
17
web/components.json
Normal file
@@ -0,0 +1,17 @@
|
||||
{
|
||||
"$schema": "https://ui.shadcn.com/schema.json",
|
||||
"style": "default",
|
||||
"rsc": true,
|
||||
"tsx": true,
|
||||
"tailwind": {
|
||||
"config": "tailwind.config.ts",
|
||||
"css": "src/app/globals.css",
|
||||
"baseColor": "neutral",
|
||||
"cssVariables": true,
|
||||
"prefix": ""
|
||||
},
|
||||
"aliases": {
|
||||
"components": "@/components",
|
||||
"utils": "@/lib/utils"
|
||||
}
|
||||
}
|
||||
199
web/package-lock.json
generated
199
web/package-lock.json
generated
@@ -14,7 +14,9 @@
|
||||
"@phosphor-icons/react": "^2.0.8",
|
||||
"@radix-ui/react-dialog": "^1.0.5",
|
||||
"@radix-ui/react-popover": "^1.0.7",
|
||||
"@radix-ui/react-slot": "^1.1.0",
|
||||
"@radix-ui/react-tooltip": "^1.0.7",
|
||||
"@tanstack/react-table": "^8.19.3",
|
||||
"@tremor/react": "^3.9.2",
|
||||
"@types/js-cookie": "^3.0.3",
|
||||
"@types/lodash": "^4.17.0",
|
||||
@@ -24,9 +26,13 @@
|
||||
"@types/react-dom": "18.0.11",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"autoprefixer": "^10.4.14",
|
||||
"class-variance-authority": "^0.7.0",
|
||||
"clsx": "^2.1.1",
|
||||
"formik": "^2.2.9",
|
||||
"fs": "^0.0.1-security",
|
||||
"js-cookie": "^3.0.5",
|
||||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.416.0",
|
||||
"mdast-util-find-and-replace": "^3.0.1",
|
||||
"next": "^14.2.3",
|
||||
"npm": "^10.8.0",
|
||||
@@ -39,12 +45,15 @@
|
||||
"react-loader-spinner": "^5.4.5",
|
||||
"react-markdown": "^9.0.1",
|
||||
"react-select": "^5.8.0",
|
||||
"recharts": "^2.12.7",
|
||||
"rehype-prism-plus": "^2.0.0",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"semver": "^7.5.4",
|
||||
"sharp": "^0.32.6",
|
||||
"swr": "^2.1.5",
|
||||
"tailwind-merge": "^2.4.0",
|
||||
"tailwindcss": "^3.3.1",
|
||||
"tailwindcss-animate": "^1.0.7",
|
||||
"typescript": "5.0.3",
|
||||
"uuid": "^9.0.1",
|
||||
"yup": "^1.1.1"
|
||||
@@ -1249,6 +1258,25 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-dialog/node_modules/@radix-ui/react-slot": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz",
|
||||
"integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.13.10",
|
||||
"@radix-ui/react-compose-refs": "1.0.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-dismissable-layer": {
|
||||
"version": "1.0.5",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-dismissable-layer/-/react-dismissable-layer-1.0.5.tgz",
|
||||
@@ -1373,6 +1401,25 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-popover/node_modules/@radix-ui/react-slot": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz",
|
||||
"integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.13.10",
|
||||
"@radix-ui/react-compose-refs": "1.0.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-popper": {
|
||||
"version": "1.1.3",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-popper/-/react-popper-1.1.3.tgz",
|
||||
@@ -1475,10 +1522,11 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-slot": {
|
||||
"node_modules/@radix-ui/react-primitive/node_modules/@radix-ui/react-slot": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz",
|
||||
"integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.13.10",
|
||||
"@radix-ui/react-compose-refs": "1.0.1"
|
||||
@@ -1493,6 +1541,39 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-slot": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.0.tgz",
|
||||
"integrity": "sha512-FUCf5XMfmW4dtYl69pdS4DbxKy8nj4M7SafBgPllysxmdachynNflAdp/gCsnYWNDnge6tI9onzMp5ARYc1KNw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-compose-refs": "1.1.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-slot/node_modules/@radix-ui/react-compose-refs": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.0.tgz",
|
||||
"integrity": "sha512-b4inOtiaOnYf9KWyO3jAeeCG6FeyfY6ldiEPanbUjWd+xIk5wZeHa8yVwmrJ2vderhu/BQvzCrJI0lHd+wIiqw==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip": {
|
||||
"version": "1.0.7",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-tooltip/-/react-tooltip-1.0.7.tgz",
|
||||
@@ -1527,6 +1608,25 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-tooltip/node_modules/@radix-ui/react-slot": {
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.0.2.tgz",
|
||||
"integrity": "sha512-YeTpuq4deV+6DusvVUW4ivBgnkHwECUu0BiN43L5UCDFgdhsRUWAghhTF5MbvNTPzmiFOx90asDSUjWuCNapwg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@babel/runtime": "^7.13.10",
|
||||
"@radix-ui/react-compose-refs": "1.0.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-use-callback-ref": {
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-use-callback-ref/-/react-use-callback-ref-1.0.1.tgz",
|
||||
@@ -1699,6 +1799,26 @@
|
||||
"tailwindcss": ">=3.0.0 || insiders"
|
||||
}
|
||||
},
|
||||
"node_modules/@tanstack/react-table": {
|
||||
"version": "8.19.3",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/react-table/-/react-table-8.19.3.tgz",
|
||||
"integrity": "sha512-MtgPZc4y+cCRtU16y1vh1myuyZ2OdkWgMEBzyjYsoMWMicKZGZvcDnub3Zwb6XF2pj9iRMvm1SO1n57lS0vXLw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@tanstack/table-core": "8.19.3"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/tannerlinsley"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": ">=16.8",
|
||||
"react-dom": ">=16.8"
|
||||
}
|
||||
},
|
||||
"node_modules/@tanstack/react-virtual": {
|
||||
"version": "3.5.0",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/react-virtual/-/react-virtual-3.5.0.tgz",
|
||||
@@ -1715,6 +1835,19 @@
|
||||
"react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tanstack/table-core": {
|
||||
"version": "8.19.3",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/table-core/-/table-core-8.19.3.tgz",
|
||||
"integrity": "sha512-IqREj9ADoml9zCAouIG/5kCGoyIxPFdqdyoxis9FisXFi5vT+iYfEfLosq4xkU/iDbMcEuAj+X8dWRLvKYDNoQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/tannerlinsley"
|
||||
}
|
||||
},
|
||||
"node_modules/@tanstack/virtual-core": {
|
||||
"version": "3.5.0",
|
||||
"resolved": "https://registry.npmjs.org/@tanstack/virtual-core/-/virtual-core-3.5.0.tgz",
|
||||
@@ -1743,6 +1876,16 @@
|
||||
"react-dom": ">=16.6.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@tremor/react/node_modules/tailwind-merge": {
|
||||
"version": "1.14.0",
|
||||
"resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-1.14.0.tgz",
|
||||
"integrity": "sha512-3mFKyCo/MBcgyOTlrY8T7odzZFx+w+qKSMAmdFzRvqBfLlSigU6TZnlFHK0lkMwj9Bj8OYU+9yW9lmGuS0QEnQ==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/dcastil"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/d3-array": {
|
||||
"version": "3.2.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/d3-array/-/d3-array-3.2.1.tgz",
|
||||
@@ -2792,6 +2935,27 @@
|
||||
"resolved": "https://registry.npmjs.org/chownr/-/chownr-1.1.4.tgz",
|
||||
"integrity": "sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg=="
|
||||
},
|
||||
"node_modules/class-variance-authority": {
|
||||
"version": "0.7.0",
|
||||
"resolved": "https://registry.npmjs.org/class-variance-authority/-/class-variance-authority-0.7.0.tgz",
|
||||
"integrity": "sha512-jFI8IQw4hczaL4ALINxqLEXQbWcNjoSkloa4IaufXCJr6QawJyw7tuRysRsrE8w2p/4gGaxKIt/hX3qz/IbD1A==",
|
||||
"license": "Apache-2.0",
|
||||
"dependencies": {
|
||||
"clsx": "2.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"url": "https://joebell.co.uk"
|
||||
}
|
||||
},
|
||||
"node_modules/class-variance-authority/node_modules/clsx": {
|
||||
"version": "2.0.0",
|
||||
"resolved": "https://registry.npmjs.org/clsx/-/clsx-2.0.0.tgz",
|
||||
"integrity": "sha512-rQ1+kcj+ttHG0MKVGBUXwayCCF1oh39BF5COIpRzuCEv8Mwjv0XucrI2ExNTOn9IlLifGClWQcU9BrZORvtw6Q==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/client-only": {
|
||||
"version": "0.0.1",
|
||||
"resolved": "https://registry.npmjs.org/client-only/-/client-only-0.0.1.tgz",
|
||||
@@ -2801,6 +2965,7 @@
|
||||
"version": "2.1.1",
|
||||
"resolved": "https://registry.npmjs.org/clsx/-/clsx-2.1.1.tgz",
|
||||
"integrity": "sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
@@ -4172,6 +4337,12 @@
|
||||
"url": "https://github.com/sponsors/rawify"
|
||||
}
|
||||
},
|
||||
"node_modules/fs": {
|
||||
"version": "0.0.1-security",
|
||||
"resolved": "https://registry.npmjs.org/fs/-/fs-0.0.1-security.tgz",
|
||||
"integrity": "sha512-3XY9e1pP0CVEUCdj5BmfIZxRBTSDycnbqhIOGec9QYtmVH2fbLpj86CFWkrNOkt/Fvty4KZG5lTglL9j/gJ87w==",
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/fs-constants": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/fs-constants/-/fs-constants-1.0.0.tgz",
|
||||
@@ -5477,6 +5648,15 @@
|
||||
"node": "14 || >=16.14"
|
||||
}
|
||||
},
|
||||
"node_modules/lucide-react": {
|
||||
"version": "0.416.0",
|
||||
"resolved": "https://registry.npmjs.org/lucide-react/-/lucide-react-0.416.0.tgz",
|
||||
"integrity": "sha512-wPWxTzdss1CTz2aqcNWNlbh4YSnH9neJWP3RaeXepxpLCTW+pmu7WcT/wxJe+Q7Y7DqGOxAqakJv0pIK3431Ag==",
|
||||
"license": "ISC",
|
||||
"peerDependencies": {
|
||||
"react": "^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/markdown-table": {
|
||||
"version": "3.0.3",
|
||||
"resolved": "https://registry.npmjs.org/markdown-table/-/markdown-table-3.0.3.tgz",
|
||||
@@ -9856,6 +10036,7 @@
|
||||
"version": "2.12.7",
|
||||
"resolved": "https://registry.npmjs.org/recharts/-/recharts-2.12.7.tgz",
|
||||
"integrity": "sha512-hlLJMhPQfv4/3NBSAyq3gzGg4h2v69RJh6KU7b3pXYNNAELs9kEoXOjbkxdXpALqKBoVmVptGfLpxdaVYqjmXQ==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"clsx": "^2.0.0",
|
||||
"eventemitter3": "^4.0.1",
|
||||
@@ -10777,9 +10958,10 @@
|
||||
"integrity": "sha512-Cat63mxsVJlzYvN51JmVXIgNoUokrIaT2zLclCXjRd8boZ0004U4KCs/sToJ75C6sdlByWxpYnb5Boif1VSFew=="
|
||||
},
|
||||
"node_modules/tailwind-merge": {
|
||||
"version": "1.14.0",
|
||||
"resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-1.14.0.tgz",
|
||||
"integrity": "sha512-3mFKyCo/MBcgyOTlrY8T7odzZFx+w+qKSMAmdFzRvqBfLlSigU6TZnlFHK0lkMwj9Bj8OYU+9yW9lmGuS0QEnQ==",
|
||||
"version": "2.4.0",
|
||||
"resolved": "https://registry.npmjs.org/tailwind-merge/-/tailwind-merge-2.4.0.tgz",
|
||||
"integrity": "sha512-49AwoOQNKdqKPd9CViyH5wJoSKsCDjUlzL8DxuGp3P1FsGY36NJDAa18jLZcaHAUUuTj+JB8IAo8zWgBNvBF7A==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/dcastil"
|
||||
@@ -10821,6 +11003,15 @@
|
||||
"node": ">=14.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/tailwindcss-animate": {
|
||||
"version": "1.0.7",
|
||||
"resolved": "https://registry.npmjs.org/tailwindcss-animate/-/tailwindcss-animate-1.0.7.tgz",
|
||||
"integrity": "sha512-bl6mpH3T7I3UFxuvDEXLxy/VuFxBk5bbzplh7tXI68mwMokNYd1t9qPBHlnyTwfa4JGC4zP516I1hYYtQ/vspA==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"tailwindcss": ">=3.0.0 || insiders"
|
||||
}
|
||||
},
|
||||
"node_modules/tailwindcss/node_modules/postcss-selector-parser": {
|
||||
"version": "6.0.16",
|
||||
"resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-6.0.16.tgz",
|
||||
|
||||
@@ -15,7 +15,9 @@
|
||||
"@phosphor-icons/react": "^2.0.8",
|
||||
"@radix-ui/react-dialog": "^1.0.5",
|
||||
"@radix-ui/react-popover": "^1.0.7",
|
||||
"@radix-ui/react-slot": "^1.1.0",
|
||||
"@radix-ui/react-tooltip": "^1.0.7",
|
||||
"@tanstack/react-table": "^8.19.3",
|
||||
"@tremor/react": "^3.9.2",
|
||||
"@types/js-cookie": "^3.0.3",
|
||||
"@types/lodash": "^4.17.0",
|
||||
@@ -25,9 +27,13 @@
|
||||
"@types/react-dom": "18.0.11",
|
||||
"@types/uuid": "^9.0.8",
|
||||
"autoprefixer": "^10.4.14",
|
||||
"class-variance-authority": "^0.7.0",
|
||||
"clsx": "^2.1.1",
|
||||
"formik": "^2.2.9",
|
||||
"fs": "^0.0.1-security",
|
||||
"js-cookie": "^3.0.5",
|
||||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.416.0",
|
||||
"mdast-util-find-and-replace": "^3.0.1",
|
||||
"next": "^14.2.3",
|
||||
"npm": "^10.8.0",
|
||||
@@ -40,12 +46,15 @@
|
||||
"react-loader-spinner": "^5.4.5",
|
||||
"react-markdown": "^9.0.1",
|
||||
"react-select": "^5.8.0",
|
||||
"recharts": "^2.12.7",
|
||||
"rehype-prism-plus": "^2.0.0",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"semver": "^7.5.4",
|
||||
"sharp": "^0.32.6",
|
||||
"swr": "^2.1.5",
|
||||
"tailwind-merge": "^2.4.0",
|
||||
"tailwindcss": "^3.3.1",
|
||||
"tailwindcss-animate": "^1.0.7",
|
||||
"typescript": "5.0.3",
|
||||
"uuid": "^9.0.1",
|
||||
"yup": "^1.1.1"
|
||||
|
||||
@@ -154,7 +154,6 @@ export default function AddConnector({
|
||||
initialValues={createConnectorInitialValues(connector)}
|
||||
validationSchema={createConnectorValidationSchema(connector)}
|
||||
onSubmit={async (values) => {
|
||||
console.log(" Iam submiing the connector");
|
||||
const {
|
||||
name,
|
||||
groups,
|
||||
|
||||
44
web/src/app/chat/AIMessageSequenceUtils.ts
Normal file
44
web/src/app/chat/AIMessageSequenceUtils.ts
Normal file
@@ -0,0 +1,44 @@
|
||||
// This module handles AI message sequences - consecutive AI messages that are streamed
|
||||
// separately but represent a single logical message. These utilities are used for
|
||||
// processing and displaying such sequences in the chat interface.
|
||||
|
||||
import { Message } from "@/app/chat/interfaces";
|
||||
import { DanswerDocument } from "@/lib/search/interfaces";
|
||||
|
||||
// Retrieves the consecutive AI messages at the end of the message history.
|
||||
// This is useful for combining or processing the latest AI response sequence.
|
||||
export function getConsecutiveAIMessagesAtEnd(
|
||||
messageHistory: Message[]
|
||||
): Message[] {
|
||||
const aiMessages = [];
|
||||
for (let i = messageHistory.length - 1; i >= 0; i--) {
|
||||
if (messageHistory[i]?.type === "assistant") {
|
||||
aiMessages.unshift(messageHistory[i]);
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
return aiMessages;
|
||||
}
|
||||
|
||||
// Extracts unique documents from a sequence of AI messages.
|
||||
// This is used to compile a comprehensive list of referenced documents
|
||||
// across multiple parts of an AI response.
|
||||
export function getUniqueDocumentsFromAIMessages(
|
||||
messages: Message[]
|
||||
): DanswerDocument[] {
|
||||
const uniqueDocumentsMap = new Map<string, DanswerDocument>();
|
||||
|
||||
messages.forEach((message) => {
|
||||
if (message.documents) {
|
||||
message.documents.forEach((doc) => {
|
||||
const uniqueKey = `${doc.document_id}-${doc.chunk_ind}`;
|
||||
if (!uniqueDocumentsMap.has(uniqueKey)) {
|
||||
uniqueDocumentsMap.set(uniqueKey, doc);
|
||||
}
|
||||
});
|
||||
}
|
||||
});
|
||||
|
||||
return Array.from(uniqueDocumentsMap.values());
|
||||
}
|
||||
@@ -101,6 +101,8 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import { SEARCH_TOOL_NAME } from "./tools/constants";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import { ApiKeyModal } from "@/components/llm/ApiKeyModal";
|
||||
import { Button } from "@tremor/react";
|
||||
import dynamic from "next/dynamic";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@@ -133,7 +135,6 @@ export function ChatPage({
|
||||
} = useChatContext();
|
||||
|
||||
const [showApiKeyModal, setShowApiKeyModal] = useState(true);
|
||||
|
||||
const { user, refreshUser, isLoadingUser } = useUser();
|
||||
|
||||
// chat session
|
||||
@@ -248,13 +249,13 @@ export function ChatPage({
|
||||
if (
|
||||
lastMessage &&
|
||||
lastMessage.type === "assistant" &&
|
||||
lastMessage.toolCalls[0] &&
|
||||
lastMessage.toolCalls[0].tool_result === undefined
|
||||
lastMessage.toolCall &&
|
||||
lastMessage.toolCall.tool_result === undefined
|
||||
) {
|
||||
const newCompleteMessageMap = new Map(
|
||||
currentMessageMap(completeMessageDetail)
|
||||
);
|
||||
const updatedMessage = { ...lastMessage, toolCalls: [] };
|
||||
const updatedMessage = { ...lastMessage, toolCall: null };
|
||||
newCompleteMessageMap.set(lastMessage.messageId, updatedMessage);
|
||||
updateCompleteMessageDetail(currentSession, newCompleteMessageMap);
|
||||
}
|
||||
@@ -483,7 +484,7 @@ export function ChatPage({
|
||||
message: "",
|
||||
type: "system",
|
||||
files: [],
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: null,
|
||||
childrenMessageIds: [firstMessageId],
|
||||
latestChildMessageId: firstMessageId,
|
||||
@@ -510,6 +511,7 @@ export function ChatPage({
|
||||
}
|
||||
newCompleteMessageMap.set(message.messageId, message);
|
||||
});
|
||||
|
||||
// if specified, make these new message the latest of the current message chain
|
||||
if (makeLatestChildMessage) {
|
||||
const currentMessageChain = buildLatestMessageChain(
|
||||
@@ -1044,8 +1046,6 @@ export function ChatPage({
|
||||
resetInputBar();
|
||||
let messageUpdates: Message[] | null = null;
|
||||
|
||||
let answer = "";
|
||||
|
||||
let stopReason: StreamStopReason | null = null;
|
||||
let query: string | null = null;
|
||||
let retrievalType: RetrievalType =
|
||||
@@ -1058,12 +1058,14 @@ export function ChatPage({
|
||||
let stackTrace: string | null = null;
|
||||
|
||||
let finalMessage: BackendMessage | null = null;
|
||||
let toolCalls: ToolCallMetadata[] = [];
|
||||
let toolCall: ToolCallMetadata | null = null;
|
||||
|
||||
let initialFetchDetails: null | {
|
||||
user_message_id: number;
|
||||
assistant_message_id: number;
|
||||
frozenMessageMap: Map<number, Message>;
|
||||
initialDynamicParentMessage: Message;
|
||||
initialDynamicAssistantMessage: Message;
|
||||
} = null;
|
||||
|
||||
try {
|
||||
@@ -1122,7 +1124,16 @@ export function ChatPage({
|
||||
return new Promise((resolve) => setTimeout(resolve, ms));
|
||||
};
|
||||
|
||||
let updateFn = (messages: Message[]) => {
|
||||
return upsertToCompleteMessageMap({
|
||||
messages: messages,
|
||||
chatSessionId: currChatSessionId,
|
||||
});
|
||||
};
|
||||
|
||||
await delay(50);
|
||||
let dynamicParentMessage: Message | null = null;
|
||||
let dynamicAssistantMessage: Message | null = null;
|
||||
while (!stack.isComplete || !stack.isEmpty()) {
|
||||
await delay(0.5);
|
||||
|
||||
@@ -1132,6 +1143,7 @@ export function ChatPage({
|
||||
continue;
|
||||
}
|
||||
|
||||
console.log(packet);
|
||||
if (!initialFetchDetails) {
|
||||
if (!Object.hasOwn(packet, "user_message_id")) {
|
||||
console.error(
|
||||
@@ -1156,12 +1168,12 @@ export function ChatPage({
|
||||
messageUpdates = [
|
||||
{
|
||||
messageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
? regenerationRequest?.messageId
|
||||
: user_message_id,
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
|
||||
},
|
||||
];
|
||||
@@ -1176,22 +1188,109 @@ export function ChatPage({
|
||||
});
|
||||
}
|
||||
|
||||
const { messageMap: currentFrozenMessageMap } =
|
||||
let { messageMap: currentFrozenMessageMap } =
|
||||
upsertToCompleteMessageMap({
|
||||
messages: messageUpdates,
|
||||
chatSessionId: currChatSessionId,
|
||||
});
|
||||
|
||||
const frozenMessageMap = currentFrozenMessageMap;
|
||||
let frozenMessageMap = currentFrozenMessageMap;
|
||||
regenerationRequest?.parentMessage;
|
||||
let initialDynamicParentMessage: Message = regenerationRequest
|
||||
? regenerationRequest?.parentMessage
|
||||
: {
|
||||
messageId: user_message_id!,
|
||||
message: "",
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCall: null,
|
||||
parentMessageId: error ? null : lastSuccessfulMessageId,
|
||||
childrenMessageIds: [assistant_message_id!],
|
||||
latestChildMessageId: -100,
|
||||
};
|
||||
|
||||
let initialDynamicAssistantMessage: Message = {
|
||||
messageId: assistant_message_id!,
|
||||
message: "",
|
||||
type: "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents: finalMessage?.context_docs?.top_documents || documents,
|
||||
citations: finalMessage?.citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCall: finalMessage?.tool_call || toolCall,
|
||||
parentMessageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: user_message_id,
|
||||
alternateAssistantID: alternativeAssistant?.id,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
};
|
||||
|
||||
initialFetchDetails = {
|
||||
frozenMessageMap,
|
||||
assistant_message_id,
|
||||
user_message_id,
|
||||
initialDynamicParentMessage,
|
||||
initialDynamicAssistantMessage,
|
||||
};
|
||||
|
||||
resetRegenerationState();
|
||||
} else {
|
||||
const { user_message_id, frozenMessageMap } = initialFetchDetails;
|
||||
let {
|
||||
initialDynamicParentMessage,
|
||||
initialDynamicAssistantMessage,
|
||||
user_message_id,
|
||||
frozenMessageMap,
|
||||
} = initialFetchDetails;
|
||||
|
||||
if (
|
||||
dynamicParentMessage === null &&
|
||||
dynamicAssistantMessage === null
|
||||
) {
|
||||
dynamicParentMessage = initialDynamicParentMessage;
|
||||
dynamicAssistantMessage = initialDynamicAssistantMessage;
|
||||
|
||||
dynamicParentMessage.message = currMessage;
|
||||
}
|
||||
|
||||
if (!dynamicAssistantMessage || !dynamicParentMessage) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (Object.hasOwn(packet, "user_message_id")) {
|
||||
let newParentMessageId = dynamicParentMessage.messageId;
|
||||
const messageResponseIDInfo = packet as MessageResponseIDInfo;
|
||||
|
||||
for (const key in dynamicAssistantMessage) {
|
||||
(dynamicParentMessage as Record<string, any>)[key] = (
|
||||
dynamicAssistantMessage as Record<string, any>
|
||||
)[key];
|
||||
}
|
||||
|
||||
dynamicParentMessage.parentMessageId = newParentMessageId;
|
||||
dynamicParentMessage.latestChildMessageId =
|
||||
messageResponseIDInfo.reserved_assistant_message_id;
|
||||
dynamicParentMessage.childrenMessageIds = [
|
||||
messageResponseIDInfo.reserved_assistant_message_id,
|
||||
];
|
||||
|
||||
dynamicParentMessage.messageId =
|
||||
messageResponseIDInfo.user_message_id!;
|
||||
dynamicAssistantMessage = {
|
||||
messageId: messageResponseIDInfo.reserved_assistant_message_id,
|
||||
type: "assistant",
|
||||
message: "",
|
||||
documents: [],
|
||||
retrievalType: undefined,
|
||||
toolCall: null,
|
||||
files: [],
|
||||
parentMessageId: dynamicParentMessage.messageId,
|
||||
childrenMessageIds: [],
|
||||
latestChildMessageId: null,
|
||||
};
|
||||
}
|
||||
|
||||
setChatState((prevState) => {
|
||||
if (prevState.get(chatSessionIdRef.current!) === "loading") {
|
||||
@@ -1204,37 +1303,37 @@ export function ChatPage({
|
||||
});
|
||||
|
||||
if (Object.hasOwn(packet, "answer_piece")) {
|
||||
answer += (packet as AnswerPiecePacket).answer_piece;
|
||||
dynamicAssistantMessage.message += (
|
||||
packet as AnswerPiecePacket
|
||||
).answer_piece;
|
||||
} else if (Object.hasOwn(packet, "top_documents")) {
|
||||
documents = (packet as DocumentsResponse).top_documents;
|
||||
dynamicAssistantMessage.documents = (
|
||||
packet as DocumentsResponse
|
||||
).top_documents;
|
||||
dynamicAssistantMessage.retrievalType = RetrievalType.Search;
|
||||
retrievalType = RetrievalType.Search;
|
||||
if (documents && documents.length > 0) {
|
||||
// point to the latest message (we don't know the messageId yet, which is why
|
||||
// we have to use -1)
|
||||
setSelectedMessageForDocDisplay(user_message_id);
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "tool_name")) {
|
||||
toolCalls = [
|
||||
{
|
||||
tool_name: (packet as ToolCallMetadata).tool_name,
|
||||
tool_args: (packet as ToolCallMetadata).tool_args,
|
||||
tool_result: (packet as ToolCallMetadata).tool_result,
|
||||
},
|
||||
];
|
||||
dynamicAssistantMessage.toolCall = {
|
||||
tool_name: (packet as ToolCallMetadata).tool_name,
|
||||
tool_args: (packet as ToolCallMetadata).tool_args,
|
||||
tool_result: (packet as ToolCallMetadata).tool_result,
|
||||
};
|
||||
if (
|
||||
!toolCalls[0].tool_result ||
|
||||
toolCalls[0].tool_result == undefined
|
||||
dynamicAssistantMessage.toolCall.tool_name === SEARCH_TOOL_NAME
|
||||
) {
|
||||
dynamicAssistantMessage.query =
|
||||
dynamicAssistantMessage.toolCall.tool_args.query;
|
||||
}
|
||||
|
||||
if (
|
||||
!dynamicAssistantMessage.toolCall ||
|
||||
!dynamicAssistantMessage.toolCall.tool_result ||
|
||||
dynamicAssistantMessage.toolCall.tool_result == undefined
|
||||
) {
|
||||
updateChatState("toolBuilding", frozenSessionId);
|
||||
} else {
|
||||
updateChatState("streaming", frozenSessionId);
|
||||
}
|
||||
|
||||
// This will be consolidated in upcoming tool calls udpate,
|
||||
// but for now, we need to set query as early as possible
|
||||
if (toolCalls[0].tool_name == SEARCH_TOOL_NAME) {
|
||||
query = toolCalls[0].tool_args["query"];
|
||||
}
|
||||
} else if (Object.hasOwn(packet, "file_ids")) {
|
||||
aiMessageImages = (packet as ImageGenerationDisplay).file_ids.map(
|
||||
(fileId) => {
|
||||
@@ -1244,82 +1343,54 @@ export function ChatPage({
|
||||
};
|
||||
}
|
||||
);
|
||||
dynamicAssistantMessage.files = aiMessageImages;
|
||||
} else if (Object.hasOwn(packet, "error")) {
|
||||
error = (packet as StreamingError).error;
|
||||
stackTrace = (packet as StreamingError).stack_trace;
|
||||
dynamicAssistantMessage.stackTrace = (
|
||||
packet as StreamingError
|
||||
).stack_trace;
|
||||
} else if (Object.hasOwn(packet, "message_id")) {
|
||||
finalMessage = packet as BackendMessage;
|
||||
dynamicAssistantMessage = {
|
||||
...dynamicAssistantMessage,
|
||||
...finalMessage,
|
||||
};
|
||||
} else if (Object.hasOwn(packet, "stop_reason")) {
|
||||
const stop_reason = (packet as StreamStopInfo).stop_reason;
|
||||
|
||||
if (stop_reason === StreamStopReason.CONTEXT_LENGTH) {
|
||||
updateCanContinue(true, frozenSessionId);
|
||||
}
|
||||
}
|
||||
if (!Object.hasOwn(packet, "stop_reason")) {
|
||||
updateFn = (messages: Message[]) => {
|
||||
const replacementsMap = regenerationRequest
|
||||
? new Map([
|
||||
[
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
],
|
||||
[
|
||||
dynamicParentMessage?.messageId,
|
||||
dynamicAssistantMessage?.messageId,
|
||||
],
|
||||
] as [number, number][])
|
||||
: null;
|
||||
|
||||
// on initial message send, we insert a dummy system message
|
||||
// set this as the parent here if no parent is set
|
||||
parentMessage =
|
||||
parentMessage || frozenMessageMap?.get(SYSTEM_MESSAGE_ID)!;
|
||||
return upsertToCompleteMessageMap({
|
||||
messages: messages,
|
||||
replacementsMap: replacementsMap,
|
||||
completeMessageMapOverride: frozenMessageMap,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
};
|
||||
|
||||
const updateFn = (messages: Message[]) => {
|
||||
const replacementsMap = regenerationRequest
|
||||
? new Map([
|
||||
[
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
regenerationRequest?.parentMessage?.messageId,
|
||||
],
|
||||
[
|
||||
regenerationRequest?.messageId,
|
||||
initialFetchDetails?.assistant_message_id,
|
||||
],
|
||||
] as [number, number][])
|
||||
: null;
|
||||
|
||||
return upsertToCompleteMessageMap({
|
||||
messages: messages,
|
||||
replacementsMap: replacementsMap,
|
||||
completeMessageMapOverride: frozenMessageMap,
|
||||
chatSessionId: frozenSessionId!,
|
||||
});
|
||||
};
|
||||
|
||||
updateFn([
|
||||
{
|
||||
messageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: initialFetchDetails.user_message_id!,
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
parentMessageId: error ? null : lastSuccessfulMessageId,
|
||||
childrenMessageIds: [
|
||||
...(regenerationRequest?.parentMessage?.childrenMessageIds ||
|
||||
[]),
|
||||
initialFetchDetails.assistant_message_id!,
|
||||
],
|
||||
latestChildMessageId: initialFetchDetails.assistant_message_id,
|
||||
},
|
||||
{
|
||||
messageId: initialFetchDetails.assistant_message_id!,
|
||||
message: error || answer,
|
||||
type: error ? "error" : "assistant",
|
||||
retrievalType,
|
||||
query: finalMessage?.rephrased_query || query,
|
||||
documents:
|
||||
finalMessage?.context_docs?.top_documents || documents,
|
||||
citations: finalMessage?.citations || {},
|
||||
files: finalMessage?.files || aiMessageImages || [],
|
||||
toolCalls: finalMessage?.tool_calls || toolCalls,
|
||||
parentMessageId: regenerationRequest
|
||||
? regenerationRequest?.parentMessage?.messageId!
|
||||
: initialFetchDetails.user_message_id,
|
||||
alternateAssistantID: alternativeAssistant?.id,
|
||||
stackTrace: stackTrace,
|
||||
overridden_model: finalMessage?.overridden_model,
|
||||
stopReason: stopReason,
|
||||
},
|
||||
]);
|
||||
let { messageMap } = updateFn([
|
||||
dynamicParentMessage,
|
||||
dynamicAssistantMessage,
|
||||
]);
|
||||
frozenMessageMap = messageMap;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1333,7 +1404,7 @@ export function ChatPage({
|
||||
message: currMessage,
|
||||
type: "user",
|
||||
files: currentMessageFiles,
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId: parentMessage?.messageId || SYSTEM_MESSAGE_ID,
|
||||
},
|
||||
{
|
||||
@@ -1343,7 +1414,7 @@ export function ChatPage({
|
||||
message: errorMsg,
|
||||
type: "error",
|
||||
files: aiMessageImages || [],
|
||||
toolCalls: [],
|
||||
toolCall: null,
|
||||
parentMessageId:
|
||||
initialFetchDetails?.user_message_id || TEMP_USER_MESSAGE_ID,
|
||||
},
|
||||
@@ -1962,9 +2033,8 @@ export function ChatPage({
|
||||
completeMessageDetail
|
||||
);
|
||||
const messageReactComponentKey = `${i}-${currentSessionId()}`;
|
||||
const parentMessage = message.parentMessageId
|
||||
? messageMap.get(message.parentMessageId)
|
||||
: null;
|
||||
const parentMessage =
|
||||
i > 1 ? messageHistory[i - 1] : null;
|
||||
if (message.type === "user") {
|
||||
if (
|
||||
(currentSessionChatState == "loading" &&
|
||||
@@ -2055,6 +2125,25 @@ export function ChatPage({
|
||||
) {
|
||||
return <></>;
|
||||
}
|
||||
const mostRecentNonAIParent = messageHistory
|
||||
.slice(0, i)
|
||||
.reverse()
|
||||
.find((msg) => msg.type !== "assistant");
|
||||
|
||||
const hasChildMessage =
|
||||
message.latestChildMessageId !== null &&
|
||||
message.latestChildMessageId !== undefined;
|
||||
const childMessage = hasChildMessage
|
||||
? messageMap.get(
|
||||
message.latestChildMessageId!
|
||||
)
|
||||
: null;
|
||||
|
||||
const hasParentAI =
|
||||
parentMessage?.type == "assistant";
|
||||
const hasChildAI =
|
||||
childMessage?.type == "assistant";
|
||||
|
||||
return (
|
||||
<div
|
||||
id={`message-${message.messageId}`}
|
||||
@@ -2066,6 +2155,9 @@ export function ChatPage({
|
||||
}
|
||||
>
|
||||
<AIMessage
|
||||
setPopup={setPopup}
|
||||
hasChildAI={hasChildAI}
|
||||
hasParentAI={hasParentAI}
|
||||
continueGenerating={
|
||||
i == messageHistory.length - 1 &&
|
||||
currentCanContinue()
|
||||
@@ -2075,7 +2167,7 @@ export function ChatPage({
|
||||
overriddenModel={message.overridden_model}
|
||||
regenerate={createRegenerator({
|
||||
messageId: message.messageId,
|
||||
parentMessage: parentMessage!,
|
||||
parentMessage: mostRecentNonAIParent!,
|
||||
})}
|
||||
otherMessagesCanSwitchTo={
|
||||
parentMessage?.childrenMessageIds || []
|
||||
@@ -2112,18 +2204,15 @@ export function ChatPage({
|
||||
}
|
||||
messageId={message.messageId}
|
||||
content={message.message}
|
||||
// content={message.message}
|
||||
files={message.files}
|
||||
query={
|
||||
messageHistory[i]?.query || undefined
|
||||
}
|
||||
personaName={liveAssistant.name}
|
||||
citedDocuments={getCitedDocumentsFromMessage(
|
||||
message
|
||||
)}
|
||||
toolCall={
|
||||
message.toolCalls &&
|
||||
message.toolCalls[0]
|
||||
message.toolCall && message.toolCall
|
||||
}
|
||||
isComplete={
|
||||
i !== messageHistory.length - 1 ||
|
||||
@@ -2147,7 +2236,6 @@ export function ChatPage({
|
||||
])
|
||||
}
|
||||
handleSearchQueryEdit={
|
||||
i === messageHistory.length - 1 &&
|
||||
currentSessionChatState == "input"
|
||||
? (newQuery) => {
|
||||
if (!previousMessage) {
|
||||
@@ -2231,7 +2319,6 @@ export function ChatPage({
|
||||
<AIMessage
|
||||
currentPersona={liveAssistant}
|
||||
messageId={message.messageId}
|
||||
personaName={liveAssistant.name}
|
||||
content={
|
||||
<p className="text-red-700 text-sm my-auto">
|
||||
{message.message}
|
||||
@@ -2279,7 +2366,6 @@ export function ChatPage({
|
||||
alternativeAssistant
|
||||
}
|
||||
messageId={null}
|
||||
personaName={liveAssistant.name}
|
||||
content={
|
||||
<div
|
||||
key={"Generating"}
|
||||
@@ -2299,7 +2385,6 @@ export function ChatPage({
|
||||
<AIMessage
|
||||
currentPersona={liveAssistant}
|
||||
messageId={-1}
|
||||
personaName={liveAssistant.name}
|
||||
content={
|
||||
<p className="text-red-700 text-sm my-auto">
|
||||
{loadingError}
|
||||
|
||||
@@ -1,70 +1,78 @@
|
||||
import { FiFileText } from "react-icons/fi";
|
||||
import { useState, useRef, useEffect } from "react";
|
||||
import { Tooltip } from "@/components/tooltip/Tooltip";
|
||||
import { ExpandTwoIcon } from "@/components/icons/icons";
|
||||
|
||||
export function DocumentPreview({
|
||||
fileName,
|
||||
maxWidth,
|
||||
alignBubble,
|
||||
open,
|
||||
}: {
|
||||
fileName: string;
|
||||
open?: () => void;
|
||||
maxWidth?: string;
|
||||
alignBubble?: boolean;
|
||||
}) {
|
||||
const [isOverflowing, setIsOverflowing] = useState(false);
|
||||
const fileNameRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (fileNameRef.current) {
|
||||
setIsOverflowing(
|
||||
fileNameRef.current.scrollWidth > fileNameRef.current.clientWidth
|
||||
);
|
||||
}
|
||||
}, [fileName]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`
|
||||
${alignBubble && "w-64"}
|
||||
flex
|
||||
items-center
|
||||
p-2
|
||||
p-3
|
||||
bg-hover
|
||||
border
|
||||
border-border
|
||||
rounded-md
|
||||
rounded-lg
|
||||
box-border
|
||||
h-16
|
||||
h-20
|
||||
hover:shadow-sm
|
||||
transition-all
|
||||
`}
|
||||
>
|
||||
<div className="flex-shrink-0">
|
||||
<div
|
||||
className="
|
||||
w-12
|
||||
h-12
|
||||
w-14
|
||||
h-14
|
||||
bg-document
|
||||
flex
|
||||
items-center
|
||||
justify-center
|
||||
rounded-md
|
||||
rounded-lg
|
||||
transition-all
|
||||
duration-200
|
||||
hover:bg-document-dark
|
||||
"
|
||||
>
|
||||
<FiFileText className="w-6 h-6 text-white" />
|
||||
<FiFileText className="w-7 h-7 text-white" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="ml-4 relative">
|
||||
<div className="ml-4 flex-grow">
|
||||
<Tooltip content={fileName} side="top" align="start">
|
||||
<div
|
||||
ref={fileNameRef}
|
||||
className={`font-medium text-sm line-clamp-1 break-all ellipses ${
|
||||
className={`font-medium text-sm line-clamp-1 break-all ellipsis ${
|
||||
maxWidth ? maxWidth : "max-w-48"
|
||||
}`}
|
||||
>
|
||||
{fileName}
|
||||
</div>
|
||||
</Tooltip>
|
||||
<div className="text-subtle text-sm">Document</div>
|
||||
<div className="text-subtle text-xs mt-1">Document</div>
|
||||
</div>
|
||||
{open && (
|
||||
<button
|
||||
onClick={() => open()}
|
||||
className="ml-2 p-2 rounded-full hover:bg-gray-200 transition-colors duration-200"
|
||||
aria-label="Expand document"
|
||||
>
|
||||
<ExpandTwoIcon className="w-5 h-5 text-gray-600" />
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import {
|
||||
SearchDanswerDocument,
|
||||
StreamStopReason,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { GraphChunk } from "./message/Messages";
|
||||
|
||||
export enum RetrievalType {
|
||||
None = "none",
|
||||
@@ -32,6 +33,7 @@ export enum ChatFileType {
|
||||
IMAGE = "image",
|
||||
DOCUMENT = "document",
|
||||
PLAIN_TEXT = "plain_text",
|
||||
CSV = "csv",
|
||||
}
|
||||
|
||||
export interface FileDescriptor {
|
||||
@@ -85,7 +87,7 @@ export interface Message {
|
||||
documents?: DanswerDocument[] | null;
|
||||
citations?: CitationMap;
|
||||
files: FileDescriptor[];
|
||||
toolCalls: ToolCallMetadata[];
|
||||
toolCall: ToolCallMetadata | null;
|
||||
// for rebuilding the message tree
|
||||
parentMessageId: number | null;
|
||||
childrenMessageIds?: number[];
|
||||
@@ -94,6 +96,7 @@ export interface Message {
|
||||
stackTrace?: string | null;
|
||||
overridden_model?: string;
|
||||
stopReason?: StreamStopReason | null;
|
||||
graphs?: GraphChunk[];
|
||||
}
|
||||
|
||||
export interface BackendChatSession {
|
||||
@@ -120,7 +123,7 @@ export interface BackendMessage {
|
||||
time_sent: string;
|
||||
citations: CitationMap;
|
||||
files: FileDescriptor[];
|
||||
tool_calls: ToolCallFinalResult[];
|
||||
tool_call: ToolCallFinalResult | null;
|
||||
alternate_assistant_id?: number | null;
|
||||
overridden_model?: string;
|
||||
}
|
||||
@@ -143,3 +146,10 @@ export interface StreamingError {
|
||||
error: string;
|
||||
stack_trace: string;
|
||||
}
|
||||
|
||||
export interface ImageGenerationResult {
|
||||
revised_prompt: string;
|
||||
url: string;
|
||||
}
|
||||
|
||||
export type ImageGenerationResults = ImageGenerationResult[];
|
||||
|
||||
@@ -60,6 +60,7 @@ export function getChatRetentionInfo(
|
||||
showRetentionWarning,
|
||||
};
|
||||
}
|
||||
import { GraphChunk } from "./message/Messages";
|
||||
|
||||
export async function updateModelOverrideForChatSession(
|
||||
chatSessionId: number,
|
||||
@@ -113,7 +114,8 @@ export type PacketType =
|
||||
| ImageGenerationDisplay
|
||||
| StreamingError
|
||||
| MessageResponseIDInfo
|
||||
| StreamStopInfo;
|
||||
| StreamStopInfo
|
||||
| GraphChunk;
|
||||
|
||||
export async function* sendMessage({
|
||||
regenerate,
|
||||
@@ -435,7 +437,7 @@ export function processRawChatHistory(
|
||||
citations: messageInfo?.citations || {},
|
||||
}
|
||||
: {}),
|
||||
toolCalls: messageInfo.tool_calls,
|
||||
toolCall: messageInfo.tool_call,
|
||||
parentMessageId: messageInfo.parent_message,
|
||||
childrenMessageIds: [],
|
||||
latestChildMessageId: messageInfo.latest_child_message,
|
||||
@@ -479,6 +481,7 @@ export function buildLatestMessageChain(
|
||||
let currMessage: Message | null = rootMessage;
|
||||
while (currMessage) {
|
||||
finalMessageList.push(currMessage);
|
||||
|
||||
const childMessageNumber = currMessage.latestChildMessageId;
|
||||
if (childMessageNumber && messageMap.has(childMessageNumber)) {
|
||||
currMessage = messageMap.get(childMessageNumber) as Message;
|
||||
|
||||
81
web/src/app/chat/message/JSONUpload.tsx
Normal file
81
web/src/app/chat/message/JSONUpload.tsx
Normal file
@@ -0,0 +1,81 @@
|
||||
import React, { useState } from "react";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { Button } from "@/components/Button";
|
||||
import { uploadFile } from "@/app/admin/assistants/lib";
|
||||
|
||||
interface JSONUploadProps {
|
||||
onUploadSuccess: (jsonData: any) => void;
|
||||
}
|
||||
|
||||
export const JSONUpload: React.FC<JSONUploadProps> = ({ onUploadSuccess }) => {
|
||||
const { popup, setPopup } = usePopup();
|
||||
const [credentialJsonStr, setCredentialJsonStr] = useState<
|
||||
string | undefined
|
||||
>();
|
||||
const [file, setFile] = useState<File | null>(null);
|
||||
|
||||
const uploadJSON = async () => {
|
||||
if (!file) {
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: "Please select a file to upload.",
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
let parsedData;
|
||||
if (file.type === "application/json") {
|
||||
parsedData = JSON.parse(credentialJsonStr!);
|
||||
} else {
|
||||
parsedData = credentialJsonStr;
|
||||
}
|
||||
|
||||
const response = await uploadFile(file);
|
||||
console.log(response);
|
||||
|
||||
onUploadSuccess(parsedData);
|
||||
setPopup({
|
||||
type: "success",
|
||||
message: "File uploaded successfully!",
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Error uploading file:", error);
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to upload file - ${error}`,
|
||||
});
|
||||
}
|
||||
};
|
||||
// 155056ca-dade-4825-bac9-efe86e7bda54
|
||||
return (
|
||||
<div>
|
||||
{popup}
|
||||
<input
|
||||
className="mr-3 text-sm text-gray-900 border border-gray-300 rounded-lg cursor-pointer bg-background dark:text-gray-400 focus:outline-none dark:bg-gray-700 dark:border-gray-600 dark:placeholder-gray-400"
|
||||
type="file"
|
||||
onChange={(event) => {
|
||||
if (!event.target.files) {
|
||||
return;
|
||||
}
|
||||
const file = event.target.files[0];
|
||||
setFile(file);
|
||||
const reader = new FileReader();
|
||||
|
||||
reader.onload = function (loadEvent) {
|
||||
if (!loadEvent?.target?.result) {
|
||||
return;
|
||||
}
|
||||
const fileContents = loadEvent.target.result;
|
||||
setCredentialJsonStr(fileContents as string);
|
||||
};
|
||||
|
||||
reader.readAsText(file);
|
||||
}}
|
||||
/>
|
||||
<Button disabled={!credentialJsonStr} onClick={uploadJSON}>
|
||||
Upload JSON
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,6 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
FiImage,
|
||||
FiEdit2,
|
||||
FiChevronRight,
|
||||
FiChevronLeft,
|
||||
@@ -8,25 +9,22 @@ import {
|
||||
FiGlobe,
|
||||
} from "react-icons/fi";
|
||||
import { FeedbackType } from "../types";
|
||||
import {
|
||||
Dispatch,
|
||||
SetStateAction,
|
||||
useContext,
|
||||
useEffect,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
import { useContext, useEffect, useRef, useState } from "react";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import {
|
||||
DanswerDocument,
|
||||
FilteredDanswerDocument,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { SearchSummary } from "./SearchSummary";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { SkippedSearch } from "./SkippedSearch";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import { CopyButton } from "@/components/CopyButton";
|
||||
import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces";
|
||||
import {
|
||||
ChatFileType,
|
||||
FileDescriptor,
|
||||
ImageGenerationResults,
|
||||
ToolCallMetadata,
|
||||
} from "../interfaces";
|
||||
import {
|
||||
IMAGE_GENERATION_TOOL_NAME,
|
||||
SEARCH_TOOL_NAME,
|
||||
@@ -44,13 +42,11 @@ import "./custom-code-styles.css";
|
||||
import { Persona } from "@/app/admin/assistants/interfaces";
|
||||
import { AssistantIcon } from "@/components/assistants/AssistantIcon";
|
||||
import { Citation } from "@/components/search/results/Citation";
|
||||
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
|
||||
|
||||
import {
|
||||
ThumbsUpIcon,
|
||||
ThumbsDownIcon,
|
||||
LikeFeedback,
|
||||
DislikeFeedback,
|
||||
ToolCallIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import {
|
||||
CustomTooltip,
|
||||
@@ -59,12 +55,23 @@ import {
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { Tooltip } from "@/components/tooltip/Tooltip";
|
||||
import { useMouseTracking } from "./hooks";
|
||||
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import GeneratingImageDisplay from "../tools/GeneratingImageDisplay";
|
||||
import RegenerateOption from "../RegenerateOption";
|
||||
import { LlmOverride } from "@/lib/hooks";
|
||||
import { ContinueGenerating } from "./ContinueMessage";
|
||||
import DualPromptDisplay from "../tools/ImagePromptCitaiton";
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { Popover } from "@/components/popover/Popover";
|
||||
import { DISABLED_CSV_DISPLAY } from "@/lib/constants";
|
||||
import {
|
||||
LineChartDisplay,
|
||||
ModalChartWrapper,
|
||||
} from "../../../components/chat_display/graphs/LineChartDisplay";
|
||||
import BarChartDisplay from "@/components/chat_display/graphs/BarChart";
|
||||
import ToolResult, {
|
||||
FileWrapper,
|
||||
} from "@/components/chat_display/InteractiveToolResult";
|
||||
|
||||
const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
SEARCH_TOOL_NAME,
|
||||
@@ -80,8 +87,12 @@ function FileDisplay({
|
||||
alignBubble?: boolean;
|
||||
}) {
|
||||
const imageFiles = files.filter((file) => file.type === ChatFileType.IMAGE);
|
||||
const nonImgFiles = files.filter((file) => file.type !== ChatFileType.IMAGE);
|
||||
const nonImgFiles = files.filter(
|
||||
(file) => file.type !== ChatFileType.IMAGE && file.type !== ChatFileType.CSV
|
||||
);
|
||||
const csvImgFiles = files.filter((file) => file.type == ChatFileType.CSV);
|
||||
|
||||
const [close, setClose] = useState(true);
|
||||
return (
|
||||
<>
|
||||
{nonImgFiles && nonImgFiles.length > 0 && (
|
||||
@@ -104,6 +115,36 @@ function FileDisplay({
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{csvImgFiles && csvImgFiles.length > 0 && (
|
||||
<div className={` ${alignBubble && "ml-auto"} mt-2 auto mb-4`}>
|
||||
<div className="flex flex-col gap-2">
|
||||
{csvImgFiles.map((file) => {
|
||||
return (
|
||||
<div key={file.id} className="w-fit">
|
||||
{close && !DISABLED_CSV_DISPLAY ? (
|
||||
<>
|
||||
<ToolResult
|
||||
csvFileDescriptor={file}
|
||||
close={() => setClose(false)}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<DocumentPreview
|
||||
open={
|
||||
DISABLED_CSV_DISPLAY ? undefined : () => setClose(true)
|
||||
}
|
||||
fileName={file.name || file.id}
|
||||
maxWidth="max-w-64"
|
||||
alignBubble={alignBubble}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{/* <LineChartDisplay /> */}
|
||||
{imageFiles && imageFiles.length > 0 && (
|
||||
<div
|
||||
id="danswer-image"
|
||||
@@ -120,7 +161,20 @@ function FileDisplay({
|
||||
);
|
||||
}
|
||||
|
||||
enum GraphType {
|
||||
BAR_CHART = "bar_chart",
|
||||
LINE_GRAPH = "line_graph",
|
||||
}
|
||||
|
||||
export interface GraphChunk {
|
||||
file_id: string;
|
||||
plot_data: Record<string, any> | null;
|
||||
graph_type: GraphType | null;
|
||||
}
|
||||
|
||||
export const AIMessage = ({
|
||||
hasChildAI,
|
||||
hasParentAI,
|
||||
regenerate,
|
||||
overriddenModel,
|
||||
continueGenerating,
|
||||
@@ -129,12 +183,12 @@ export const AIMessage = ({
|
||||
toggleDocumentSelection,
|
||||
alternativeAssistant,
|
||||
docs,
|
||||
graphs = [],
|
||||
messageId,
|
||||
content,
|
||||
files,
|
||||
selectedDocuments,
|
||||
query,
|
||||
personaName,
|
||||
citedDocuments,
|
||||
toolCall,
|
||||
isComplete,
|
||||
@@ -148,8 +202,12 @@ export const AIMessage = ({
|
||||
currentPersona,
|
||||
otherMessagesCanSwitchTo,
|
||||
onMessageSelection,
|
||||
setPopup,
|
||||
}: {
|
||||
shared?: boolean;
|
||||
hasChildAI?: boolean;
|
||||
hasParentAI?: boolean;
|
||||
graphs?: GraphChunk[];
|
||||
isActive?: boolean;
|
||||
continueGenerating?: () => void;
|
||||
otherMessagesCanSwitchTo?: number[];
|
||||
@@ -163,9 +221,8 @@ export const AIMessage = ({
|
||||
content: string | JSX.Element;
|
||||
files?: FileDescriptor[];
|
||||
query?: string;
|
||||
personaName?: string;
|
||||
citedDocuments?: [string, DanswerDocument][] | null;
|
||||
toolCall?: ToolCallMetadata;
|
||||
toolCall?: ToolCallMetadata | null;
|
||||
isComplete?: boolean;
|
||||
hasDocs?: boolean;
|
||||
handleFeedback?: (feedbackType: FeedbackType) => void;
|
||||
@@ -176,7 +233,11 @@ export const AIMessage = ({
|
||||
retrievalDisabled?: boolean;
|
||||
overriddenModel?: string;
|
||||
regenerate?: (modelOverRide: LlmOverride) => Promise<void>;
|
||||
setPopup?: (popupSpec: PopupSpec | null) => void;
|
||||
}) => {
|
||||
console.log(toolCall);
|
||||
const [isPopoverOpen, setIsPopoverOpen] = useState(false);
|
||||
|
||||
const toolCallGenerating = toolCall && !toolCall.tool_result;
|
||||
const processContent = (content: string | JSX.Element) => {
|
||||
if (typeof content !== "string") {
|
||||
@@ -199,12 +260,20 @@ export const AIMessage = ({
|
||||
return content;
|
||||
}
|
||||
}
|
||||
if (
|
||||
isComplete &&
|
||||
toolCall?.tool_result &&
|
||||
toolCall.tool_name == IMAGE_GENERATION_TOOL_NAME
|
||||
) {
|
||||
return content + ` [${toolCall.tool_name}]()`;
|
||||
}
|
||||
|
||||
return content + (!isComplete && !toolCallGenerating ? " [*]() " : "");
|
||||
};
|
||||
const finalContent = processContent(content as string);
|
||||
|
||||
const [isRegenerateHovered, setIsRegenerateHovered] = useState(false);
|
||||
|
||||
const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking();
|
||||
|
||||
const settings = useContext(SettingsContext);
|
||||
@@ -274,39 +343,50 @@ export const AIMessage = ({
|
||||
<div
|
||||
id="danswer-ai-message"
|
||||
ref={trackedElementRef}
|
||||
className={"py-5 ml-4 px-5 relative flex "}
|
||||
className={`${hasParentAI ? "pb-5" : "py-5"} px-2 lg:px-5 relative flex `}
|
||||
>
|
||||
<div
|
||||
className={`mx-auto ${shared ? "w-full" : "w-[90%]"} max-w-message-max`}
|
||||
>
|
||||
<div className={`desktop:mr-12 ${!shared && "mobile:ml-0 md:ml-8"}`}>
|
||||
<div className="flex">
|
||||
<AssistantIcon
|
||||
size="small"
|
||||
assistant={alternativeAssistant || currentPersona}
|
||||
/>
|
||||
{!hasParentAI ? (
|
||||
<AssistantIcon
|
||||
size="small"
|
||||
assistant={alternativeAssistant || currentPersona}
|
||||
/>
|
||||
) : (
|
||||
<div className="w-6" />
|
||||
)}
|
||||
|
||||
<div className="w-full">
|
||||
<div className="max-w-message-max break-words">
|
||||
<div className="w-full ml-4">
|
||||
<div className="max-w-message-max break-words">
|
||||
{!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME ? (
|
||||
{(!toolCall || toolCall.tool_name === SEARCH_TOOL_NAME) && (
|
||||
<>
|
||||
{query !== undefined &&
|
||||
handleShowRetrieved !== undefined &&
|
||||
isCurrentlyShowingRetrieved !== undefined &&
|
||||
!retrievalDisabled && (
|
||||
<div className="mb-1">
|
||||
<SearchSummary
|
||||
docs={docs}
|
||||
filteredDocs={filteredDocs}
|
||||
query={query}
|
||||
finished={toolCall?.tool_result != undefined}
|
||||
hasDocs={hasDocs || false}
|
||||
messageId={messageId}
|
||||
handleShowRetrieved={handleShowRetrieved}
|
||||
finished={
|
||||
toolCall?.tool_result != undefined ||
|
||||
isComplete!
|
||||
}
|
||||
toggleDocumentSelection={
|
||||
toggleDocumentSelection
|
||||
}
|
||||
handleSearchQueryEdit={handleSearchQueryEdit}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{handleForceSearch &&
|
||||
!hasChildAI &&
|
||||
content &&
|
||||
query === undefined &&
|
||||
!hasDocs &&
|
||||
@@ -318,7 +398,7 @@ export const AIMessage = ({
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
) : null}
|
||||
)}
|
||||
|
||||
{toolCall &&
|
||||
!TOOLS_WITH_CUSTOM_HANDLING.includes(
|
||||
@@ -344,21 +424,53 @@ export const AIMessage = ({
|
||||
|
||||
{toolCall &&
|
||||
toolCall.tool_name === INTERNET_SEARCH_TOOL_NAME && (
|
||||
<ToolRunDisplay
|
||||
toolName={
|
||||
toolCall.tool_result
|
||||
? `Searched the internet`
|
||||
: `Searching the internet`
|
||||
}
|
||||
toolLogo={
|
||||
<FiGlobe size={15} className="my-auto mr-1" />
|
||||
}
|
||||
isRunning={!toolCall.tool_result}
|
||||
/>
|
||||
<div className="my-2">
|
||||
<ToolRunDisplay
|
||||
toolName={
|
||||
toolCall.tool_result
|
||||
? `Searched the internet`
|
||||
: `Searching the internet`
|
||||
}
|
||||
toolLogo={
|
||||
<FiGlobe size={15} className="my-auto mr-1" />
|
||||
}
|
||||
isRunning={!toolCall.tool_result}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{graphs.map((graph, ind) => {
|
||||
return graph.graph_type === GraphType.LINE_GRAPH ? (
|
||||
<ModalChartWrapper
|
||||
key={ind}
|
||||
chartType="line"
|
||||
fileId={graph.file_id}
|
||||
>
|
||||
<LineChartDisplay fileId={graph.file_id} />
|
||||
</ModalChartWrapper>
|
||||
) : (
|
||||
<ModalChartWrapper
|
||||
key={ind}
|
||||
chartType="bar"
|
||||
fileId={graph.file_id}
|
||||
>
|
||||
<BarChartDisplay fileId={graph.file_id} />
|
||||
</ModalChartWrapper>
|
||||
);
|
||||
})}
|
||||
|
||||
{content || files ? (
|
||||
<>
|
||||
{toolCall?.tool_name == "create_graph" && (
|
||||
<ModalChartWrapper
|
||||
key={0}
|
||||
chartType="line"
|
||||
fileId={toolCall?.tool_result?.file_id}
|
||||
>
|
||||
<LineChartDisplay
|
||||
fileId={toolCall?.tool_result?.file_id}
|
||||
/>
|
||||
</ModalChartWrapper>
|
||||
)}
|
||||
<FileDisplay files={files || []} />
|
||||
|
||||
{typeof content === "string" ? (
|
||||
@@ -371,6 +483,50 @@ export const AIMessage = ({
|
||||
const { node, ...rest } = props;
|
||||
const value = rest.children;
|
||||
|
||||
if (
|
||||
value
|
||||
?.toString()
|
||||
.startsWith(IMAGE_GENERATION_TOOL_NAME)
|
||||
) {
|
||||
const imageGenerationResult =
|
||||
toolCall?.tool_result as ImageGenerationResults;
|
||||
|
||||
return (
|
||||
<Popover
|
||||
open={isPopoverOpen}
|
||||
onOpenChange={
|
||||
() => null
|
||||
// setIsPopoverOpen(isPopoverOpen => !isPopoverOpen)
|
||||
} // only allow closing from the icon
|
||||
content={
|
||||
<button
|
||||
onMouseDown={() => {
|
||||
setIsPopoverOpen(!isPopoverOpen);
|
||||
}}
|
||||
>
|
||||
<ToolCallIcon className="cursor-pointer flex-none text-blue-500 hover:text-blue-700 !h-4 !w-4 inline-block" />
|
||||
</button>
|
||||
}
|
||||
popover={
|
||||
<DualPromptDisplay
|
||||
arg="Prompt"
|
||||
setPopup={setPopup!}
|
||||
prompt1={
|
||||
imageGenerationResult[0]
|
||||
.revised_prompt
|
||||
}
|
||||
prompt2={
|
||||
imageGenerationResult[1]
|
||||
.revised_prompt
|
||||
}
|
||||
/>
|
||||
}
|
||||
side="top"
|
||||
align="center"
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (value?.toString().startsWith("*")) {
|
||||
return (
|
||||
<div className="flex-none bg-background-800 inline-block rounded-full h-3 w-3 ml-2" />
|
||||
@@ -378,10 +534,6 @@ export const AIMessage = ({
|
||||
} else if (
|
||||
value?.toString().startsWith("[")
|
||||
) {
|
||||
// for some reason <a> tags cause the onClick to not apply
|
||||
// and the links are unclickable
|
||||
// TODO: fix the fact that you have to double click to follow link
|
||||
// for the first link
|
||||
return (
|
||||
<Citation link={rest?.href}>
|
||||
{rest.children}
|
||||
@@ -428,82 +580,22 @@ export const AIMessage = ({
|
||||
) : isComplete ? null : (
|
||||
<></>
|
||||
)}
|
||||
{isComplete && docs && docs.length > 0 && (
|
||||
<div className="mt-2 -mx-8 w-full mb-4 flex relative">
|
||||
<div className="w-full">
|
||||
<div className="px-8 flex gap-x-2">
|
||||
{!settings?.isMobile &&
|
||||
filteredDocs.length > 0 &&
|
||||
filteredDocs.slice(0, 2).map((doc, ind) => (
|
||||
<div
|
||||
key={doc.document_id}
|
||||
className={`w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 pb-2 pt-1 border-b
|
||||
`}
|
||||
>
|
||||
<a
|
||||
href={doc.link || undefined}
|
||||
target="_blank"
|
||||
className="text-sm flex w-full pt-1 gap-x-1.5 overflow-hidden justify-between font-semibold text-text-700"
|
||||
>
|
||||
<Citation link={doc.link} index={ind + 1} />
|
||||
<p className="shrink truncate ellipsis break-all">
|
||||
{doc.semantic_identifier ||
|
||||
doc.document_id}
|
||||
</p>
|
||||
<div className="ml-auto flex-none">
|
||||
{doc.is_internet ? (
|
||||
<InternetSearchIcon url={doc.link} />
|
||||
) : (
|
||||
<SourceIcon
|
||||
sourceType={doc.source_type}
|
||||
iconSize={18}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</a>
|
||||
<div className="flex overscroll-x-scroll mt-.5">
|
||||
<DocumentMetadataBlock document={doc} />
|
||||
</div>
|
||||
<div className="line-clamp-3 text-xs break-words pt-1">
|
||||
{doc.blurb}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
<div
|
||||
onClick={() => {
|
||||
if (toggleDocumentSelection) {
|
||||
toggleDocumentSelection();
|
||||
}
|
||||
}}
|
||||
key={-1}
|
||||
className="cursor-pointer w-[200px] rounded-lg flex-none transition-all duration-500 hover:bg-background-125 bg-text-100 px-4 py-2 border-b"
|
||||
>
|
||||
<div className="text-sm flex justify-between font-semibold text-text-700">
|
||||
<p className="line-clamp-1">See context</p>
|
||||
<div className="flex gap-x-1">
|
||||
{uniqueSources.map((sourceType, ind) => {
|
||||
return (
|
||||
<div key={ind} className="flex-none">
|
||||
<SourceIcon
|
||||
sourceType={sourceType}
|
||||
iconSize={18}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
<div className="line-clamp-3 text-xs break-words pt-1">
|
||||
See more
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
{/* <ModalChartWrapper chartType="line" fileId="fee2ff90-4ebe-43fc-858f-a95c73385da4" >
|
||||
<LineChartDisplay fileId="fee2ff90-4ebe-43fc-858f-a95c73385da4" />
|
||||
</ModalChartWrapper> */}
|
||||
{/*
|
||||
<ModalChartWrapper chartType="bar" fileId={"0ad36971-9353-42de-b89d-9c3361d3c3eb"}>
|
||||
<BarChartDisplay fileId={"0ad36971-9353-42de-b89d-9c3361d3c3eb"} />
|
||||
</ModalChartWrapper>
|
||||
|
||||
{handleFeedback &&
|
||||
<ModalChartWrapper chartType="other" fileId={"066fc31f-56f0-48fb-98d3-ffd46f1ac0f5"}>
|
||||
<ImageDisplay fileId={"066fc31f-56f0-48fb-98d3-ffd46f1ac0f5"} />
|
||||
</ModalChartWrapper>
|
||||
*/}
|
||||
|
||||
{!hasChildAI &&
|
||||
handleFeedback &&
|
||||
(isActive ? (
|
||||
<div
|
||||
className={`
|
||||
|
||||
@@ -4,9 +4,21 @@ import {
|
||||
} from "@/components/BasicClickable";
|
||||
import { HoverPopup } from "@/components/HoverPopup";
|
||||
import { Hoverable } from "@/components/Hoverable";
|
||||
import { InternetSearchIcon } from "@/components/InternetSearchIcon";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { ChevronDownIcon, InfoIcon } from "@/components/icons/icons";
|
||||
import { DocumentMetadataBlock } from "@/components/search/DocumentDisplay";
|
||||
import { Citation } from "@/components/search/results/Citation";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { Tooltip } from "@/components/tooltip/Tooltip";
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import {
|
||||
DanswerDocument,
|
||||
FilteredDanswerDocument,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import { useContext, useEffect, useRef, useState } from "react";
|
||||
import { FiCheck, FiEdit2, FiSearch, FiX } from "react-icons/fi";
|
||||
import { DownChevron } from "react-select/dist/declarations/src/components/indicators";
|
||||
|
||||
export function ShowHideDocsButton({
|
||||
messageId,
|
||||
@@ -37,24 +49,31 @@ export function ShowHideDocsButton({
|
||||
|
||||
export function SearchSummary({
|
||||
query,
|
||||
hasDocs,
|
||||
filteredDocs,
|
||||
finished,
|
||||
messageId,
|
||||
handleShowRetrieved,
|
||||
docs,
|
||||
toggleDocumentSelection,
|
||||
handleSearchQueryEdit,
|
||||
}: {
|
||||
toggleDocumentSelection?: () => void;
|
||||
docs?: DanswerDocument[] | null;
|
||||
filteredDocs: FilteredDanswerDocument[];
|
||||
finished: boolean;
|
||||
query: string;
|
||||
hasDocs: boolean;
|
||||
messageId: number | null;
|
||||
handleShowRetrieved: (messageId: number | null) => void;
|
||||
handleSearchQueryEdit?: (query: string) => void;
|
||||
}) {
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
const [finalQuery, setFinalQuery] = useState(query);
|
||||
const [isOverflowed, setIsOverflowed] = useState(false);
|
||||
const searchingForRef = useRef<HTMLDivElement>(null);
|
||||
const editQueryRef = useRef<HTMLInputElement>(null);
|
||||
const [isDropdownOpen, setIsDropdownOpen] = useState(false);
|
||||
const searchingForRef = useRef<HTMLDivElement | null>(null);
|
||||
const editQueryRef = useRef<HTMLInputElement | null>(null);
|
||||
|
||||
const settings = useContext(SettingsContext);
|
||||
|
||||
const toggleDropdown = () => {
|
||||
setIsDropdownOpen(!isDropdownOpen);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const checkOverflow = () => {
|
||||
@@ -68,7 +87,7 @@ export function SearchSummary({
|
||||
};
|
||||
|
||||
checkOverflow();
|
||||
window.addEventListener("resize", checkOverflow); // Recheck on window resize
|
||||
window.addEventListener("resize", checkOverflow);
|
||||
|
||||
return () => window.removeEventListener("resize", checkOverflow);
|
||||
}, []);
|
||||
@@ -86,15 +105,30 @@ export function SearchSummary({
|
||||
}, [query]);
|
||||
|
||||
const searchingForDisplay = (
|
||||
<div className={`flex p-1 rounded ${isOverflowed && "cursor-default"}`}>
|
||||
<FiSearch className="flex-none mr-2 my-auto" size={14} />
|
||||
<div
|
||||
className={`${!finished && "loading-text"}
|
||||
!text-sm !line-clamp-1 !break-all px-0.5`}
|
||||
ref={searchingForRef}
|
||||
>
|
||||
{finished ? "Searched" : "Searching"} for: <i> {finalQuery}</i>
|
||||
</div>
|
||||
<div
|
||||
className={`flex my-auto items-center ${isOverflowed && "cursor-default"}`}
|
||||
>
|
||||
{finished ? (
|
||||
<>
|
||||
<div
|
||||
onClick={() => {
|
||||
toggleDropdown();
|
||||
}}
|
||||
className={`transition-colors duration-300 group-hover:text-text-toolhover cursor-pointer text-text-toolrun !line-clamp-1 !break-all pr-0.5`}
|
||||
ref={searchingForRef}
|
||||
>
|
||||
Searched {filteredDocs.length > 0 && filteredDocs.length} document
|
||||
{filteredDocs.length != 1 && "s"} for {query}
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<div
|
||||
className={`loading-text !text-sm !line-clamp-1 !break-all px-0.5`}
|
||||
ref={searchingForRef}
|
||||
>
|
||||
Searching for: <i> {finalQuery}</i>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -145,43 +179,126 @@ export function SearchSummary({
|
||||
</div>
|
||||
</div>
|
||||
) : null;
|
||||
const SearchBlock = ({ doc, ind }: { doc: DanswerDocument; ind: number }) => {
|
||||
return (
|
||||
<div
|
||||
onClick={() => {
|
||||
if (toggleDocumentSelection) {
|
||||
toggleDocumentSelection();
|
||||
}
|
||||
}}
|
||||
key={doc.document_id}
|
||||
className={`flex items-start gap-3 px-4 py-3 text-token-text-secondary ${ind == 0 && "rounded-t-xl"} hover:bg-background-100 group relative text-sm`}
|
||||
>
|
||||
<div className="mt-1 scale-[.9] flex-none">
|
||||
{doc.is_internet ? (
|
||||
<InternetSearchIcon url={doc.link} />
|
||||
) : (
|
||||
<SourceIcon sourceType={doc.source_type} iconSize={18} />
|
||||
)}
|
||||
</div>
|
||||
<div className="flex flex-col">
|
||||
<a
|
||||
href={doc.link}
|
||||
target="_blank"
|
||||
className="line-clamp-1 text-text-900"
|
||||
>
|
||||
<p className="shrink truncate ellipsis break-all ">
|
||||
{doc.semantic_identifier || doc.document_id}
|
||||
</p>
|
||||
<p className="line-clamp-3 text-text-500 break-words">
|
||||
{doc.blurb}
|
||||
</p>
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
{isEditing ? (
|
||||
editInput
|
||||
) : (
|
||||
<>
|
||||
<div className="text-sm">
|
||||
{isOverflowed ? (
|
||||
<HoverPopup
|
||||
mainContent={searchingForDisplay}
|
||||
popupContent={
|
||||
<div>
|
||||
<b>Full query:</b>{" "}
|
||||
<div className="mt-1 italic">{query}</div>
|
||||
</div>
|
||||
}
|
||||
direction="top"
|
||||
<>
|
||||
<div className="flex gap-x-2 group">
|
||||
{isEditing ? (
|
||||
editInput
|
||||
) : (
|
||||
<>
|
||||
<div className="my-auto text-sm">
|
||||
{isOverflowed ? (
|
||||
<HoverPopup
|
||||
mainContent={searchingForDisplay}
|
||||
popupContent={
|
||||
<div>
|
||||
<b>Full query:</b>{" "}
|
||||
<div className="mt-1 italic">{query}</div>
|
||||
</div>
|
||||
}
|
||||
direction="top"
|
||||
/>
|
||||
) : (
|
||||
searchingForDisplay
|
||||
)}
|
||||
</div>
|
||||
|
||||
<button
|
||||
className="my-auto invisible group-hover:visible transition-all duration-300 rounded"
|
||||
onClick={toggleDropdown}
|
||||
>
|
||||
<ChevronDownIcon
|
||||
className={`transform transition-transform ${isDropdownOpen ? "rotate-180" : ""}`}
|
||||
/>
|
||||
</button>
|
||||
|
||||
{handleSearchQueryEdit ? (
|
||||
<Tooltip delayDuration={1000} content={"Edit Search"}>
|
||||
<button
|
||||
className="my-auto invisible group-hover:visible transition-all duration-300 cursor-pointer rounded"
|
||||
onClick={() => {
|
||||
setIsEditing(true);
|
||||
}}
|
||||
>
|
||||
<FiEdit2 />
|
||||
</button>
|
||||
</Tooltip>
|
||||
) : (
|
||||
searchingForDisplay
|
||||
<></>
|
||||
)}
|
||||
</div>
|
||||
{handleSearchQueryEdit && (
|
||||
<Tooltip delayDuration={1000} content={"Edit Search"}>
|
||||
<button
|
||||
className="my-auto hover:bg-hover p-1.5 rounded"
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{isDropdownOpen && docs && docs.length > 0 && (
|
||||
<div
|
||||
className={`mt-2 -mx-8 w-full mb-4 flex relative transition-all duration-300 ${isDropdownOpen ? "opacity-100 max-h-[1000px]" : "opacity-0 max-h-0"}`}
|
||||
>
|
||||
<div className="w-full">
|
||||
<div className="mx-8 flex rounded max-h-[500px] overflow-y-scroll rounded-lg border-1.5 border divide-y divider-y-1.5 divider-y-border border-border flex-col gap-x-4">
|
||||
{!settings?.isMobile &&
|
||||
filteredDocs.length > 0 &&
|
||||
filteredDocs.map((doc, ind) => (
|
||||
<SearchBlock key={ind} doc={doc} ind={ind} />
|
||||
))}
|
||||
|
||||
<div
|
||||
onClick={() => {
|
||||
setIsEditing(true);
|
||||
if (toggleDocumentSelection) {
|
||||
toggleDocumentSelection();
|
||||
}
|
||||
}}
|
||||
key={-1}
|
||||
className="cursor-pointer w-full flex transition-all duration-500 hover:bg-background-100 py-3 border-b"
|
||||
>
|
||||
<FiEdit2 />
|
||||
</button>
|
||||
</Tooltip>
|
||||
)}
|
||||
</>
|
||||
<div key={0} className="px-3 invisible scale-[.9] flex-none">
|
||||
<SourceIcon sourceType={"file"} iconSize={18} />
|
||||
</div>
|
||||
<div className="text-sm flex justify-between text-text-900">
|
||||
<p className="line-clamp-1">See context</p>
|
||||
<div className="flex gap-x-1"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -101,7 +101,6 @@ export function SharedChatDisplay({
|
||||
messageId={message.messageId}
|
||||
content={message.message}
|
||||
files={message.files || []}
|
||||
personaName={chatSession.persona_name}
|
||||
citedDocuments={getCitedDocumentsFromMessage(message)}
|
||||
isComplete
|
||||
/>
|
||||
|
||||
83
web/src/app/chat/tools/ImagePromptCitaiton.tsx
Normal file
83
web/src/app/chat/tools/ImagePromptCitaiton.tsx
Normal file
@@ -0,0 +1,83 @@
|
||||
import { PopupSpec } from "@/components/admin/connectors/Popup";
|
||||
import { CopyIcon } from "@/components/icons/icons";
|
||||
import { Divider } from "@tremor/react";
|
||||
import React, { forwardRef, useState } from "react";
|
||||
import { FiCheck } from "react-icons/fi";
|
||||
|
||||
interface PromptDisplayProps {
|
||||
prompt1: string;
|
||||
prompt2?: string;
|
||||
arg: string;
|
||||
setPopup: (popupSpec: PopupSpec | null) => void;
|
||||
}
|
||||
|
||||
const DualPromptDisplay = forwardRef<HTMLDivElement, PromptDisplayProps>(
|
||||
({ prompt1, prompt2, setPopup, arg }, ref) => {
|
||||
const [copied, setCopied] = useState<number | null>(null);
|
||||
|
||||
const copyToClipboard = (text: string, index: number) => {
|
||||
navigator.clipboard
|
||||
.writeText(text)
|
||||
.then(() => {
|
||||
setPopup({ message: "Copied to clipboard", type: "success" });
|
||||
setCopied(index);
|
||||
setTimeout(() => setCopied(null), 2000); // Reset copy status after 2 seconds
|
||||
})
|
||||
.catch((err) => {
|
||||
setPopup({ message: "Failed to copy", type: "error" });
|
||||
});
|
||||
};
|
||||
|
||||
const PromptSection = ({
|
||||
copied,
|
||||
prompt,
|
||||
index,
|
||||
}: {
|
||||
copied: number | null;
|
||||
prompt: string;
|
||||
index: number;
|
||||
}) => (
|
||||
<div className="w-full p-2 rounded-lg">
|
||||
<h2 className="text-lg font-semibold mb-2">
|
||||
{arg} {index + 1}
|
||||
</h2>
|
||||
|
||||
<p className="line-clamp-6 text-sm text-gray-800">{prompt}</p>
|
||||
|
||||
<button
|
||||
onMouseDown={() => copyToClipboard(prompt, index)}
|
||||
className="flex mt-2 text-sm cursor-pointer items-center justify-center py-2 px-3 border border-background-200 bg-inverted text-text-900 rounded-full hover:bg-background-100 transition duration-200"
|
||||
>
|
||||
{copied == index ? (
|
||||
<>
|
||||
<FiCheck className="mr-2" size={16} />
|
||||
Copied!
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<CopyIcon className="mr-2" size={16} />
|
||||
Copy
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="w-[400px] bg-inverted mx-auto p-6 rounded-lg shadow-lg">
|
||||
<div className="flex flex-col gap-x-4">
|
||||
<PromptSection copied={copied} prompt={prompt1} index={0} />
|
||||
{prompt2 && (
|
||||
<>
|
||||
<Divider />
|
||||
<PromptSection copied={copied} prompt={prompt2} index={1} />
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
|
||||
DualPromptDisplay.displayName = "DualPromptDisplay";
|
||||
export default DualPromptDisplay;
|
||||
@@ -15,6 +15,7 @@ interface ModalProps {
|
||||
titleSize?: string;
|
||||
hideDividerForTitle?: boolean;
|
||||
noPadding?: boolean;
|
||||
hideCloseButton?: boolean;
|
||||
}
|
||||
|
||||
export function Modal({
|
||||
@@ -27,6 +28,7 @@ export function Modal({
|
||||
hideDividerForTitle,
|
||||
noPadding,
|
||||
icon,
|
||||
hideCloseButton,
|
||||
}: ModalProps) {
|
||||
const modalRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
@@ -60,7 +62,7 @@ export function Modal({
|
||||
${noPadding ? "" : "p-10"}
|
||||
${className || ""}`}
|
||||
>
|
||||
{onOutsideClick && (
|
||||
{onOutsideClick && !hideCloseButton && (
|
||||
<div className="absolute top-2 right-2">
|
||||
<button
|
||||
onClick={onOutsideClick}
|
||||
|
||||
@@ -145,6 +145,19 @@ export function ClientLayout({
|
||||
),
|
||||
link: "/admin/tools",
|
||||
},
|
||||
...(enableEnterprise
|
||||
? [
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<ClipboardIcon size={18} />
|
||||
<div className="ml-1">Standard Answers</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/standard-answer",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
|
||||
152
web/src/components/chat_display/CSVContent.tsx
Normal file
152
web/src/components/chat_display/CSVContent.tsx
Normal file
@@ -0,0 +1,152 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { FileDescriptor } from "@/app/chat/interfaces";
|
||||
import { WarningCircle } from "@phosphor-icons/react";
|
||||
|
||||
interface CSVData {
|
||||
[key: string]: string;
|
||||
}
|
||||
|
||||
export interface ToolDisplay {
|
||||
fileDescriptor: FileDescriptor;
|
||||
isLoading: boolean;
|
||||
fadeIn: boolean;
|
||||
}
|
||||
|
||||
export const CsvContent = ({
|
||||
fileDescriptor,
|
||||
isLoading,
|
||||
fadeIn,
|
||||
}: ToolDisplay) => {
|
||||
const [data, setData] = useState<CSVData[]>([]);
|
||||
const [headers, setHeaders] = useState<string[]>([]);
|
||||
useEffect(() => {
|
||||
fetchCSV(fileDescriptor.id);
|
||||
}, [fileDescriptor.id]);
|
||||
|
||||
const fetchCSV = async (id: string) => {
|
||||
try {
|
||||
const response = await fetch(`api/chat/file/${id}`);
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch CSV file");
|
||||
}
|
||||
|
||||
const contentLength = response.headers.get("Content-Length");
|
||||
const fileSizeInMB = contentLength
|
||||
? parseInt(contentLength) / (1024 * 1024)
|
||||
: 0;
|
||||
const MAX_FILE_SIZE_MB = 5;
|
||||
|
||||
if (fileSizeInMB > MAX_FILE_SIZE_MB) {
|
||||
throw new Error("File size exceeds the maximum limit of 5MB");
|
||||
}
|
||||
|
||||
const csvData = await response.text();
|
||||
const rows = csvData.trim().split("\n");
|
||||
const parsedHeaders = rows[0].split(",");
|
||||
setHeaders(parsedHeaders);
|
||||
|
||||
const parsedData: CSVData[] = rows.slice(1).map((row) => {
|
||||
const values = row.split(",");
|
||||
return parsedHeaders.reduce<CSVData>((obj, header, index) => {
|
||||
obj[header] = values[index];
|
||||
return obj;
|
||||
}, {});
|
||||
});
|
||||
setData(parsedData);
|
||||
} catch (error) {
|
||||
console.error("Error fetching CSV file:", error);
|
||||
setData([]);
|
||||
setHeaders([]);
|
||||
}
|
||||
};
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<div className="flex items-center justify-center h-[300px]">
|
||||
<div className="animate-pulse w- flex space-x-4">
|
||||
<div className="rounded-full bg-background-200 h-10 w-10"></div>
|
||||
<div className="w-full flex-1 space-y-4 py-1">
|
||||
<div className="h-2 w-full bg-background-200 rounded"></div>
|
||||
<div className="w-full space-y-3">
|
||||
<div className="grid grid-cols-3 gap-4">
|
||||
<div className="h-2 bg-background-200 rounded col-span-2"></div>
|
||||
<div className="h-2 bg-background-200 rounded col-span-1"></div>
|
||||
</div>
|
||||
<div className="h-2 bg-background-200 rounded"></div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`transition-opacity transform relative duration-1000 ease-in-out ${fadeIn ? "opacity-100" : "opacity-0"}`}
|
||||
>
|
||||
<div className={`overflow-y-auto flex relative max-h-[400px]`}>
|
||||
<Table className="!relative !overflow-y-scroll">
|
||||
<TableHeader className="z-20 !sticky !top-0">
|
||||
<TableRow className="!bg-neutral-100">
|
||||
{headers.map((header, index) => (
|
||||
<TableHead className=" " key={index}>
|
||||
<p className="text-text-600 line-clamp-2 my-2 font-medium">
|
||||
{index === 0 ? "" : header}
|
||||
</p>
|
||||
</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
|
||||
<TableBody>
|
||||
{data.length > 0 ? (
|
||||
data.map((row, rowIndex) => (
|
||||
<TableRow key={rowIndex}>
|
||||
{headers.map((header, cellIndex) => (
|
||||
<TableCell
|
||||
className={`${
|
||||
cellIndex === 0 && "sticky left-0 !bg-neutral-100"
|
||||
}`}
|
||||
key={cellIndex}
|
||||
>
|
||||
{row[header]}
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))
|
||||
) : (
|
||||
<TableRow>
|
||||
<TableCell
|
||||
colSpan={headers.length}
|
||||
className="text-center py-8"
|
||||
>
|
||||
<div className="flex flex-col items-center justify-center space-y-2">
|
||||
<WarningCircle className="w-8 h-8 text-error" />
|
||||
<p className="text-text-600 font-medium">
|
||||
{headers.length === 0
|
||||
? "Error loading CSV"
|
||||
: "No data available"}
|
||||
</p>
|
||||
<p className="text-text-400 text-sm">
|
||||
{headers.length === 0
|
||||
? "The CSV file may be too large or couldn't be loaded properly."
|
||||
: "The CSV file appears to be empty."}
|
||||
</p>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
)}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
143
web/src/components/chat_display/InteractiveToolResult.tsx
Normal file
143
web/src/components/chat_display/InteractiveToolResult.tsx
Normal file
@@ -0,0 +1,143 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import {
|
||||
CustomTooltip,
|
||||
TooltipGroup,
|
||||
} from "@/components/tooltip/CustomTooltip";
|
||||
import {
|
||||
DexpandTwoIcon,
|
||||
DownloadCSVIcon,
|
||||
ExpandTwoIcon,
|
||||
OpenIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { Card, CardHeader, CardTitle, CardContent } from "@/components/ui/card";
|
||||
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { FileDescriptor } from "@/app/chat/interfaces";
|
||||
import { CsvContent, ToolDisplay } from "./CSVContent";
|
||||
|
||||
export default function ToolResult({
|
||||
csvFileDescriptor,
|
||||
close,
|
||||
}: {
|
||||
csvFileDescriptor: FileDescriptor;
|
||||
close: () => void;
|
||||
}) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const expand = () => setExpanded((prev) => !prev);
|
||||
|
||||
return (
|
||||
<>
|
||||
{expanded && (
|
||||
<Modal
|
||||
hideCloseButton
|
||||
onOutsideClick={() => setExpanded(false)}
|
||||
className="!max-w-5xl overflow-hidden rounded-lg animate-all ease-in !p-0"
|
||||
>
|
||||
<FileWrapper
|
||||
fileDescriptor={csvFileDescriptor}
|
||||
close={close}
|
||||
expanded={true}
|
||||
expand={expand}
|
||||
ContentComponent={CsvContent}
|
||||
/>
|
||||
</Modal>
|
||||
)}
|
||||
|
||||
<FileWrapper
|
||||
fileDescriptor={csvFileDescriptor}
|
||||
close={close}
|
||||
expanded={false}
|
||||
expand={expand}
|
||||
ContentComponent={CsvContent}
|
||||
/>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
interface FileWrapperProps {
|
||||
fileDescriptor: FileDescriptor;
|
||||
close: () => void;
|
||||
expanded: boolean;
|
||||
expand: () => void;
|
||||
ContentComponent: React.ComponentType<ToolDisplay>;
|
||||
}
|
||||
|
||||
export const FileWrapper = ({
|
||||
fileDescriptor,
|
||||
close,
|
||||
expanded,
|
||||
expand,
|
||||
ContentComponent,
|
||||
}: FileWrapperProps) => {
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [fadeIn, setFadeIn] = useState(false);
|
||||
|
||||
useEffect(() => {
|
||||
// Simulate loading
|
||||
setTimeout(() => setIsLoading(false), 300);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
if (!isLoading) {
|
||||
setTimeout(() => setFadeIn(true), 50);
|
||||
} else {
|
||||
setFadeIn(false);
|
||||
}
|
||||
}, [isLoading]);
|
||||
|
||||
const downloadFile = () => {
|
||||
// Implement download logic here
|
||||
};
|
||||
|
||||
return (
|
||||
<div
|
||||
className={`${
|
||||
!expanded ? "w-message-sm" : "w-full"
|
||||
} !rounded !rounded-lg overflow-y-hidden w-full border border-border`}
|
||||
>
|
||||
<CardHeader className="w-full !py-0 !pb-4 border-b border-border border-b-neutral-200 !pt-4 !mb-0 z-[10] top-0">
|
||||
<div className="flex justify-between items-center">
|
||||
<CardTitle className="!my-auto text-ellipsis line-clamp-1 text-xl font-semibold text-text-700 pr-4 transition-colors duration-300">
|
||||
{fileDescriptor.name}
|
||||
</CardTitle>
|
||||
<div className="flex !my-auto">
|
||||
<TooltipGroup gap="gap-x-4">
|
||||
<CustomTooltip showTick line content="Download file">
|
||||
<button onClick={downloadFile}>
|
||||
<DownloadCSVIcon className="cursor-pointer transition-colors duration-300 hover:text-text-800 h-6 w-6 text-text-400" />
|
||||
</button>
|
||||
</CustomTooltip>
|
||||
<CustomTooltip
|
||||
line
|
||||
showTick
|
||||
content={expanded ? "Minimize" : "Full screen"}
|
||||
>
|
||||
<button onClick={expand}>
|
||||
{!expanded ? (
|
||||
<ExpandTwoIcon className="transition-colors duration-300 hover:text-text-800 h-6 w-6 cursor-pointer text-text-400" />
|
||||
) : (
|
||||
<DexpandTwoIcon className="transition-colors duration-300 hover:text-text-800 h-6 w-6 cursor-pointer text-text-400" />
|
||||
)}
|
||||
</button>
|
||||
</CustomTooltip>
|
||||
<CustomTooltip showTick line content="Hide">
|
||||
<button onClick={close}>
|
||||
<OpenIcon className="transition-colors duration-300 hover:text-text-800 h-6 w-6 cursor-pointer text-text-400" />
|
||||
</button>
|
||||
</CustomTooltip>
|
||||
</TooltipGroup>
|
||||
</div>
|
||||
</div>
|
||||
</CardHeader>
|
||||
<Card className="!rounded-none w-full max-h-[600px] !p-0 relative overflow-x-scroll overflow-y-scroll mx-auto">
|
||||
<CardContent className="!p-0">
|
||||
<ContentComponent
|
||||
fileDescriptor={fileDescriptor}
|
||||
isLoading={isLoading}
|
||||
fadeIn={fadeIn}
|
||||
/>
|
||||
</CardContent>
|
||||
</Card>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
89
web/src/components/chat_display/graphs/BarChart.tsx
Normal file
89
web/src/components/chat_display/graphs/BarChart.tsx
Normal file
@@ -0,0 +1,89 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { BarChart, Bar, XAxis, YAxis, CartesianGrid, Tooltip } from "recharts";
|
||||
import { CardContent } from "@/components/ui/card";
|
||||
|
||||
interface BarDataPoint {
|
||||
x: number;
|
||||
y: number;
|
||||
width: number;
|
||||
color: string;
|
||||
}
|
||||
|
||||
interface BarPlotData {
|
||||
data: BarDataPoint[];
|
||||
title: string;
|
||||
xlabel: string;
|
||||
ylabel: string;
|
||||
xticks: number[];
|
||||
xticklabels: string[];
|
||||
}
|
||||
|
||||
export function BarChartDisplay({ fileId }: { fileId: string }) {
|
||||
const [barPlotData, setBarPlotData] = useState<BarPlotData | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
fetchPlotData(fileId);
|
||||
}, [fileId]);
|
||||
|
||||
const fetchPlotData = async (id: string) => {
|
||||
try {
|
||||
const response = await fetch(`api/chat/file/${id}`);
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch plot data");
|
||||
}
|
||||
const data: BarPlotData = await response.json();
|
||||
setBarPlotData(data);
|
||||
} catch (error) {
|
||||
console.error("Error fetching plot data:", error);
|
||||
}
|
||||
};
|
||||
|
||||
if (!barPlotData) {
|
||||
return <div>Loading...</div>;
|
||||
}
|
||||
console.log("IN THE FUNCTION");
|
||||
|
||||
// Transform data to match Recharts expected format
|
||||
const transformedData = barPlotData.data.map((point, index) => ({
|
||||
name: barPlotData.xticklabels[index] || point.x.toString(),
|
||||
value: point.y,
|
||||
}));
|
||||
|
||||
return (
|
||||
<>
|
||||
<h2>{barPlotData.title}</h2>
|
||||
<BarChart
|
||||
width={600}
|
||||
height={400}
|
||||
data={transformedData}
|
||||
margin={{
|
||||
top: 20,
|
||||
right: 30,
|
||||
left: 20,
|
||||
bottom: 5,
|
||||
}}
|
||||
>
|
||||
<CartesianGrid strokeDasharray="3 3" />
|
||||
<XAxis
|
||||
dataKey="name"
|
||||
label={{
|
||||
value: barPlotData.xlabel,
|
||||
position: "insideBottom",
|
||||
offset: -10,
|
||||
}}
|
||||
/>
|
||||
<YAxis
|
||||
label={{
|
||||
value: barPlotData.ylabel,
|
||||
angle: -90,
|
||||
position: "insideLeft",
|
||||
}}
|
||||
/>
|
||||
<Tooltip />
|
||||
<Bar dataKey="value" fill={barPlotData.data[0].color} />
|
||||
</BarChart>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default BarChartDisplay;
|
||||
56
web/src/components/chat_display/graphs/ImageDisplay.tsx
Normal file
56
web/src/components/chat_display/graphs/ImageDisplay.tsx
Normal file
@@ -0,0 +1,56 @@
|
||||
import { buildImgUrl } from "@/app/chat/files/images/utils";
|
||||
import React, { useState, useEffect } from "react";
|
||||
|
||||
export function ImageDisplay({ fileId }: { fileId: string }) {
|
||||
const [imageUrl, setImageUrl] = useState<string | null>(null);
|
||||
const [fullImageShowing, setFullImageShowing] = useState(false);
|
||||
|
||||
// useEffect(() => {
|
||||
// fetchImageUrl(fileId);
|
||||
// }, [fileId]);
|
||||
|
||||
// const fetchImageUrl = async (id: string) => {
|
||||
// try {
|
||||
// const response = await fetch(`api/chat/file/${id}`);
|
||||
// if (!response.ok) {
|
||||
// throw new Error('Failed to fetch image data');
|
||||
// }
|
||||
// const data = await response.json();
|
||||
// setImageUrl(data.imageUrl); // Assuming the API returns an object with an imageUrl field
|
||||
// } catch (error) {
|
||||
// console.error("Error fetching image data:", error);
|
||||
// }
|
||||
// };
|
||||
|
||||
// const buildImgUrl = (id: string) => {
|
||||
// // Implement your URL building logic here
|
||||
// return imageUrl || ''; // Return the fetched URL or an empty string if not available
|
||||
// };
|
||||
|
||||
return (
|
||||
// <div className="w-full h-full">
|
||||
<>
|
||||
<img
|
||||
className="w-full mx-auto object-cover object-center overflow-hidden rounded-lg w-full h-full transition-opacity duration-300 opacity-100"
|
||||
onClick={() => setFullImageShowing(true)}
|
||||
src={buildImgUrl(fileId)}
|
||||
alt="Fetched image"
|
||||
loading="lazy"
|
||||
/>
|
||||
{fullImageShowing && (
|
||||
<div
|
||||
className="fixed inset-0 bg-black bg-opacity-50 flex items-center justify-center z-50"
|
||||
onClick={() => setFullImageShowing(false)}
|
||||
>
|
||||
<img
|
||||
src={buildImgUrl(fileId)}
|
||||
alt="Full size image"
|
||||
className="max-w-90vw max-h-90vh object-contain"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
|
||||
// </div>
|
||||
);
|
||||
}
|
||||
261
web/src/components/chat_display/graphs/LineChartDisplay.tsx
Normal file
261
web/src/components/chat_display/graphs/LineChartDisplay.tsx
Normal file
@@ -0,0 +1,261 @@
|
||||
import React, { useEffect, useState } from "react";
|
||||
import { PickaxeIcon, TrendingUp } from "lucide-react";
|
||||
import { CartesianGrid, Line, LineChart, XAxis, YAxis } from "recharts";
|
||||
import {
|
||||
ChartConfig,
|
||||
ChartContainer,
|
||||
ChartTooltip,
|
||||
ChartTooltipContent,
|
||||
} from "@/components/ui/chart";
|
||||
import { CardFooter } from "@/components/ui/card";
|
||||
import {
|
||||
DexpandTwoIcon,
|
||||
DownloadCSVIcon,
|
||||
ExpandTwoIcon,
|
||||
OpenIcon,
|
||||
PaintingIcon,
|
||||
PaintingIconSkeleton,
|
||||
} from "@/components/icons/icons";
|
||||
import {
|
||||
CustomTooltip,
|
||||
TooltipGroup,
|
||||
} from "@/components/tooltip/CustomTooltip";
|
||||
import { Modal } from "@/components/Modal";
|
||||
import { ChartDataPoint, ChartType, Data, PlotData } from "./types";
|
||||
import { SelectionBackground } from "@phosphor-icons/react";
|
||||
|
||||
export function ModalChartWrapper({
|
||||
children,
|
||||
fileId,
|
||||
chartType,
|
||||
}: {
|
||||
children: JSX.Element;
|
||||
fileId: string;
|
||||
chartType: ChartType;
|
||||
}) {
|
||||
const [expanded, setExpanded] = useState(false);
|
||||
const expand = () => {
|
||||
setExpanded((expanded) => !expanded);
|
||||
};
|
||||
|
||||
return (
|
||||
<>
|
||||
{expanded ? (
|
||||
<Modal
|
||||
onOutsideClick={() => setExpanded(false)}
|
||||
className="animate-all ease-in !p-0"
|
||||
>
|
||||
<ChartWrapper
|
||||
chartType={chartType}
|
||||
expand={expand}
|
||||
expanded={expanded}
|
||||
fileId={fileId}
|
||||
>
|
||||
{children}
|
||||
</ChartWrapper>
|
||||
</Modal>
|
||||
) : (
|
||||
<ChartWrapper
|
||||
chartType={chartType}
|
||||
expand={expand}
|
||||
expanded={expanded}
|
||||
fileId={fileId}
|
||||
>
|
||||
{children}
|
||||
</ChartWrapper>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export function ChartWrapper({
|
||||
expanded,
|
||||
children,
|
||||
fileId,
|
||||
chartType,
|
||||
expand,
|
||||
}: {
|
||||
expanded: boolean;
|
||||
children: JSX.Element;
|
||||
chartType: ChartType;
|
||||
fileId: string;
|
||||
expand: () => void;
|
||||
}) {
|
||||
const [plotDataJson, setPlotDataJson] = useState<Data | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
fetchPlotData(fileId);
|
||||
}, [fileId]);
|
||||
|
||||
const fetchPlotData = async (id: string) => {
|
||||
if (chartType == "other") {
|
||||
setPlotDataJson({ title: "Uploaded Chart" });
|
||||
} else {
|
||||
try {
|
||||
const response = await fetch(`api/chat/file/${id}`);
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch plot data");
|
||||
}
|
||||
const data = await response.json();
|
||||
setPlotDataJson(data);
|
||||
} catch (error) {
|
||||
console.error("Error fetching plot data:", error);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const downloadFile = () => {
|
||||
if (!fileId) return;
|
||||
// Implement download functionality here
|
||||
};
|
||||
|
||||
if (!plotDataJson) {
|
||||
return <div>Loading...</div>;
|
||||
}
|
||||
return (
|
||||
<div className="bg-background-50 group rounded-lg shadow-md relative">
|
||||
<div className="relative p-4">
|
||||
<div className="relative flex pb-2 items-center justify-between">
|
||||
<h2 className="text-xl font-semibold mb-2">{plotDataJson.title}</h2>
|
||||
<div className="flex gap-x-2">
|
||||
<TooltipGroup>
|
||||
{chartType != "other" && (
|
||||
<CustomTooltip
|
||||
showTick
|
||||
line
|
||||
position="top"
|
||||
content="View Static file"
|
||||
>
|
||||
<button onClick={() => downloadFile()}>
|
||||
<PaintingIconSkeleton className="cursor-pointer transition-colors duration-300 hover:text-neutral-800 h-6 w-6 text-neutral-400" />
|
||||
</button>
|
||||
</CustomTooltip>
|
||||
)}
|
||||
|
||||
<CustomTooltip
|
||||
showTick
|
||||
line
|
||||
position="top"
|
||||
content="Download file"
|
||||
>
|
||||
<button onClick={() => downloadFile()}>
|
||||
<DownloadCSVIcon className="cursor-pointer ml-4 transition-colors duration-300 hover:text-neutral-800 h-6 w-6 text-neutral-400" />
|
||||
</button>
|
||||
</CustomTooltip>
|
||||
<CustomTooltip
|
||||
line
|
||||
position="top"
|
||||
content={expanded ? "Minimize" : "Full screen"}
|
||||
>
|
||||
<button onClick={() => expand()}>
|
||||
{!expanded ? (
|
||||
<ExpandTwoIcon className="transition-colors duration-300 ml-4 hover:text-neutral-800 h-6 w-6 cursor-pointer text-neutral-400" />
|
||||
) : (
|
||||
<DexpandTwoIcon className="transition-colors duration-300 ml-4 hover:text-neutral-800 h-6 w-6 cursor-pointer text-neutral-400" />
|
||||
)}
|
||||
</button>
|
||||
</CustomTooltip>
|
||||
</TooltipGroup>
|
||||
</div>
|
||||
</div>
|
||||
{children}
|
||||
{chartType === "other" && (
|
||||
<div className="absolute bottom-6 right-6 p-1.5 rounded flex gap-x-2 items-center border border-neutral-200 bg-neutral-50 opacity-0 transition-opacity duration-300 ease-in-out text-sm text-gray-500 group-hover:opacity-100">
|
||||
<SelectionBackground />
|
||||
Interactive charts of this type are not supported yet
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<CardFooter className="flex-col items-start gap-2 text-sm">
|
||||
<div className="flex gap-2 font-medium leading-none">
|
||||
Data from Matplotlib plot <TrendingUp className="h-4 w-4" />
|
||||
</div>
|
||||
</CardFooter>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function LineChartDisplay({ fileId }: { fileId: string }) {
|
||||
const [chartData, setChartData] = useState<ChartDataPoint[]>([]);
|
||||
const [chartConfig, setChartConfig] = useState<ChartConfig>({});
|
||||
|
||||
useEffect(() => {
|
||||
fetchPlotData(fileId);
|
||||
}, [fileId]);
|
||||
|
||||
const fetchPlotData = async (id: string) => {
|
||||
try {
|
||||
const response = await fetch(`api/chat/file/${id}`);
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch plot data");
|
||||
}
|
||||
const plotDataJson: PlotData = await response.json();
|
||||
console.log("plot data");
|
||||
console.log(plotDataJson);
|
||||
|
||||
const transformedData: ChartDataPoint[] = plotDataJson.data[0].x.map(
|
||||
(x, index) => ({
|
||||
x: x,
|
||||
y: plotDataJson.data[0].y[index],
|
||||
})
|
||||
);
|
||||
|
||||
setChartData(transformedData);
|
||||
|
||||
const config: ChartConfig = {
|
||||
y: {
|
||||
label: plotDataJson.data[0].label,
|
||||
color: plotDataJson.data[0].color,
|
||||
},
|
||||
};
|
||||
|
||||
setChartConfig(config);
|
||||
} catch (error) {
|
||||
console.error("Error fetching plot data:", error);
|
||||
}
|
||||
};
|
||||
console.log("chartData");
|
||||
console.log(chartData);
|
||||
|
||||
return (
|
||||
<div className="w-full h-full">
|
||||
<ChartContainer config={chartConfig}>
|
||||
<LineChart
|
||||
accessibilityLayer
|
||||
data={chartData}
|
||||
margin={{
|
||||
left: 12,
|
||||
right: 12,
|
||||
}}
|
||||
>
|
||||
<CartesianGrid vertical={false} />
|
||||
<XAxis
|
||||
dataKey="x"
|
||||
tickLine={false}
|
||||
axisLine={false}
|
||||
tickMargin={8}
|
||||
tickFormatter={(value: number) => value.toFixed(2)}
|
||||
/>
|
||||
<YAxis
|
||||
dataKey="y"
|
||||
tickLine={false}
|
||||
axisLine={false}
|
||||
tickMargin={8}
|
||||
tickFormatter={(value: number) => value.toFixed(2)}
|
||||
/>
|
||||
<ChartTooltip
|
||||
cursor={false}
|
||||
content={<ChartTooltipContent hideLabel />}
|
||||
/>
|
||||
<Line
|
||||
dataKey="y"
|
||||
type="natural"
|
||||
stroke={chartConfig.y?.color || "var(--chart-1)"}
|
||||
strokeWidth={2}
|
||||
dot={false}
|
||||
/>
|
||||
</LineChart>
|
||||
</ChartContainer>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
93
web/src/components/chat_display/graphs/PortalChart.tsx
Normal file
93
web/src/components/chat_display/graphs/PortalChart.tsx
Normal file
@@ -0,0 +1,93 @@
|
||||
import React, { useEffect, useState } from "react";
|
||||
import { TrendingUp } from "lucide-react";
|
||||
import {
|
||||
PolarAngleAxis,
|
||||
PolarGrid,
|
||||
PolarRadiusAxis,
|
||||
Radar,
|
||||
RadarChart,
|
||||
ResponsiveContainer,
|
||||
} from "recharts";
|
||||
import { CardContent } from "@/components/ui/card";
|
||||
import {
|
||||
ChartConfig,
|
||||
ChartContainer,
|
||||
ChartTooltip,
|
||||
ChartTooltipContent,
|
||||
} from "@/components/ui/chart";
|
||||
|
||||
import { PolarChartDataPoint, PolarPlotData } from "./types";
|
||||
|
||||
export function PolarChartDisplay({ fileId }: { fileId: string }) {
|
||||
const [chartData, setChartData] = useState<PolarChartDataPoint[]>([]);
|
||||
const [chartConfig, setChartConfig] = useState<ChartConfig>({});
|
||||
const [plotDataJson, setPlotDataJson] = useState<PolarPlotData | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
fetchPlotData(fileId);
|
||||
}, [fileId]);
|
||||
|
||||
const fetchPlotData = async (id: string) => {
|
||||
try {
|
||||
const response = await fetch(`api/chat/file/${id}`);
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to fetch plot data");
|
||||
}
|
||||
const data: PolarPlotData = await response.json();
|
||||
setPlotDataJson(data);
|
||||
|
||||
// Transform the JSON data to the format expected by the chart
|
||||
const transformedData: PolarChartDataPoint[] = data.data[0].theta.map(
|
||||
(angle, index) => ({
|
||||
angle: (angle * 180) / Math.PI, // Convert radians to degrees
|
||||
radius: data.data[0].r[index],
|
||||
})
|
||||
);
|
||||
|
||||
setChartData(transformedData);
|
||||
|
||||
// Create the chart config from the JSON data
|
||||
const config: ChartConfig = {
|
||||
y: {
|
||||
label: data.data[0].label,
|
||||
color: data.data[0].color,
|
||||
},
|
||||
};
|
||||
|
||||
setChartConfig(config);
|
||||
} catch (error) {
|
||||
console.error("Error fetching plot data:", error);
|
||||
}
|
||||
};
|
||||
|
||||
if (!plotDataJson) {
|
||||
return <div>Loading...</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<CardContent>
|
||||
<ChartContainer config={chartConfig}>
|
||||
<ResponsiveContainer width="100%" height={400}>
|
||||
<RadarChart cx="50%" cy="50%" outerRadius="80%" data={chartData}>
|
||||
<PolarGrid />
|
||||
<PolarAngleAxis dataKey="angle" type="number" domain={[0, 360]} />
|
||||
<PolarRadiusAxis angle={30} domain={[0, plotDataJson.rmax]} />
|
||||
<ChartTooltip
|
||||
cursor={false}
|
||||
content={<ChartTooltipContent hideLabel />}
|
||||
/>
|
||||
<Radar
|
||||
name="Polar Plot"
|
||||
dataKey="radius"
|
||||
stroke={chartConfig.y?.color || "var(--chart-1)"}
|
||||
fill={chartConfig.y?.color || "var(--chart-1)"}
|
||||
fillOpacity={0.6}
|
||||
/>
|
||||
</RadarChart>
|
||||
</ResponsiveContainer>
|
||||
</ChartContainer>
|
||||
</CardContent>
|
||||
);
|
||||
}
|
||||
|
||||
export default PolarChartDisplay;
|
||||
37
web/src/components/chat_display/graphs/types.ts
Normal file
37
web/src/components/chat_display/graphs/types.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
export interface Data {
|
||||
title: string;
|
||||
}
|
||||
|
||||
export interface PlotData extends Data {
|
||||
data: Array<{
|
||||
x: number[];
|
||||
y: number[];
|
||||
label: string;
|
||||
color: string;
|
||||
}>;
|
||||
xlabel: string;
|
||||
ylabel: string;
|
||||
}
|
||||
|
||||
export interface ChartDataPoint {
|
||||
x: number;
|
||||
y: number;
|
||||
}
|
||||
export interface PolarPlotData extends Data {
|
||||
data: Array<{
|
||||
theta: number[];
|
||||
r: number[];
|
||||
label: string;
|
||||
color: string;
|
||||
}>;
|
||||
rmax: number;
|
||||
rticks: number[];
|
||||
rlabel_position: number;
|
||||
}
|
||||
|
||||
export interface PolarChartDataPoint {
|
||||
angle: number;
|
||||
radius: number;
|
||||
}
|
||||
|
||||
export type ChartType = "line" | "bar" | "radial" | "other";
|
||||
@@ -122,6 +122,102 @@ export const AssistantsIconSkeleton = ({
|
||||
);
|
||||
};
|
||||
|
||||
export const OpenIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="200"
|
||||
height="200"
|
||||
viewBox="0 0 14 14"
|
||||
>
|
||||
<path
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M7 13.5a9.26 9.26 0 0 0-5.61-2.95a1 1 0 0 1-.89-1V1.5A1 1 0 0 1 1.64.51A9.3 9.3 0 0 1 7 3.43zm0 0a9.26 9.26 0 0 1 5.61-2.95a1 1 0 0 0 .89-1V1.5a1 1 0 0 0-1.14-.99A9.3 9.3 0 0 0 7 3.43z"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const DexpandTwoIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="200"
|
||||
height="200"
|
||||
viewBox="0 0 14 14"
|
||||
>
|
||||
<path
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="m.5 13.5l5-5m-4 0h4v4m8-12l-5 5m4 0h-4v-4"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const ExpandTwoIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="200"
|
||||
height="200"
|
||||
viewBox="0 0 14 14"
|
||||
>
|
||||
<path
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="m8.5 5.5l5-5m-4 0h4v4m-8 4l-5 5m4 0h-4v-4"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const DownloadCSVIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
width="200"
|
||||
height="200"
|
||||
viewBox="0 0 14 14"
|
||||
>
|
||||
<path
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
stroke-linecap="round"
|
||||
stroke-linejoin="round"
|
||||
d="M.5 10.5v1a2 2 0 0 0 2 2h9a2 2 0 0 0 2-2v-1M4 6l3 3.5L10 6M7 9.5v-9"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const LightBulbIcon = ({
|
||||
size,
|
||||
className = defaultTailwindCSS,
|
||||
@@ -2811,3 +2907,40 @@ export const WindowsIcon = ({
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const ToolCallIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 19 15"
|
||||
fill="none"
|
||||
>
|
||||
<path
|
||||
d="M4.42 0.75H2.8625H2.75C1.64543 0.75 0.75 1.64543 0.75 2.75V11.65C0.75 12.7546 1.64543 13.65 2.75 13.65H2.8625C2.8625 13.65 2.8625 13.65 2.8625 13.65C2.8625 13.65 4.00751 13.65 4.42 13.65M13.98 13.65H15.5375H15.65C16.7546 13.65 17.65 12.7546 17.65 11.65V2.75C17.65 1.64543 16.7546 0.75 15.65 0.75H15.5375H13.98"
|
||||
stroke="currentColor"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M5.55283 4.21963C5.25993 3.92674 4.78506 3.92674 4.49217 4.21963C4.19927 4.51252 4.19927 4.9874 4.49217 5.28029L6.36184 7.14996L4.49217 9.01963C4.19927 9.31252 4.19927 9.7874 4.49217 10.0803C4.78506 10.3732 5.25993 10.3732 5.55283 10.0803L7.95283 7.68029C8.24572 7.3874 8.24572 6.91252 7.95283 6.61963L5.55283 4.21963Z"
|
||||
fill="currentColor"
|
||||
stroke="currentColor"
|
||||
strokeWidth="0.2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M9.77753 8.75003C9.3357 8.75003 8.97753 9.10821 8.97753 9.55003C8.97753 9.99186 9.3357 10.35 9.77753 10.35H13.2775C13.7194 10.35 14.0775 9.99186 14.0775 9.55003C14.0775 9.10821 13.7194 8.75003 13.2775 8.75003H9.77753Z"
|
||||
fill="currentColor"
|
||||
stroke="currentColor"
|
||||
strokeWidth="0.1"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import {
|
||||
DanswerDocument,
|
||||
DocumentRelevance,
|
||||
Relevance,
|
||||
SearchDanswerDocument,
|
||||
} from "@/lib/search/interfaces";
|
||||
import { DocumentFeedbackBlock } from "./DocumentFeedbackBlock";
|
||||
@@ -12,11 +11,10 @@ import { PopupSpec } from "../admin/connectors/Popup";
|
||||
import { DocumentUpdatedAtBadge } from "./DocumentUpdatedAtBadge";
|
||||
import { SourceIcon } from "../SourceIcon";
|
||||
import { MetadataBadge } from "../MetadataBadge";
|
||||
import { BookIcon, CheckmarkIcon, LightBulbIcon, XIcon } from "../icons/icons";
|
||||
import { BookIcon, LightBulbIcon } from "../icons/icons";
|
||||
|
||||
import { FaStar } from "react-icons/fa";
|
||||
import { FiTag } from "react-icons/fi";
|
||||
import { DISABLE_LLM_DOC_RELEVANCE } from "@/lib/constants";
|
||||
import { SettingsContext } from "../settings/SettingsProvider";
|
||||
import { CustomTooltip, TooltipGroup } from "../tooltip/CustomTooltip";
|
||||
import { WarningCircle } from "@phosphor-icons/react";
|
||||
|
||||
@@ -19,9 +19,10 @@ const TooltipGroupContext = createContext<{
|
||||
hoverCountRef: { current: false },
|
||||
});
|
||||
|
||||
export const TooltipGroup: React.FC<{ children: React.ReactNode }> = ({
|
||||
children,
|
||||
}) => {
|
||||
export const TooltipGroup: React.FC<{
|
||||
children: React.ReactNode;
|
||||
gap?: string;
|
||||
}> = ({ children, gap = "" }) => {
|
||||
const [groupHovered, setGroupHovered] = useState(false);
|
||||
const hoverCountRef = useRef(false);
|
||||
|
||||
@@ -29,7 +30,7 @@ export const TooltipGroup: React.FC<{ children: React.ReactNode }> = ({
|
||||
<TooltipGroupContext.Provider
|
||||
value={{ groupHovered, setGroupHovered, hoverCountRef }}
|
||||
>
|
||||
<div className="inline-flex">{children}</div>
|
||||
<div className={`inline-flex ${gap}`}>{children}</div>
|
||||
</TooltipGroupContext.Provider>
|
||||
);
|
||||
};
|
||||
|
||||
56
web/src/components/ui/button.tsx
Normal file
56
web/src/components/ui/button.tsx
Normal file
@@ -0,0 +1,56 @@
|
||||
import * as React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cva, type VariantProps } from "class-variance-authority";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const buttonVariants = cva(
|
||||
"inline-flex items-center justify-center whitespace-nowrap rounded-md text-sm font-medium ring-offset-background transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:pointer-events-none disabled:opacity-50",
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
default: "bg-primary text-primary-foreground hover:bg-primary/90",
|
||||
destructive:
|
||||
"bg-destructive text-destructive-foreground hover:bg-destructive/90",
|
||||
outline:
|
||||
"border border-input bg-background hover:bg-accent hover:text-accent-foreground",
|
||||
secondary:
|
||||
"bg-secondary text-secondary-foreground hover:bg-secondary/80",
|
||||
ghost: "hover:bg-accent hover:text-accent-foreground",
|
||||
link: "text-primary underline-offset-4 hover:underline",
|
||||
},
|
||||
size: {
|
||||
default: "h-10 px-4 py-2",
|
||||
sm: "h-9 rounded-md px-3",
|
||||
lg: "h-11 rounded-md px-8",
|
||||
icon: "h-10 w-10",
|
||||
},
|
||||
},
|
||||
defaultVariants: {
|
||||
variant: "default",
|
||||
size: "default",
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
export interface ButtonProps
|
||||
extends React.ButtonHTMLAttributes<HTMLButtonElement>,
|
||||
VariantProps<typeof buttonVariants> {
|
||||
asChild?: boolean;
|
||||
}
|
||||
|
||||
const Button = React.forwardRef<HTMLButtonElement, ButtonProps>(
|
||||
({ className, variant, size, asChild = false, ...props }, ref) => {
|
||||
const Comp = asChild ? Slot : "button";
|
||||
return (
|
||||
<Comp
|
||||
className={cn(buttonVariants({ variant, size, className }))}
|
||||
ref={ref}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
);
|
||||
Button.displayName = "Button";
|
||||
|
||||
export { Button, buttonVariants };
|
||||
86
web/src/components/ui/card.tsx
Normal file
86
web/src/components/ui/card.tsx
Normal file
@@ -0,0 +1,86 @@
|
||||
import * as React from "react";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const Card = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"rounded-lg border bg-card text-card-foreground shadow-sm",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
Card.displayName = "Card";
|
||||
|
||||
const CardHeader = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn("flex flex-col space-y-1.5 py-6 px-4", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
CardHeader.displayName = "CardHeader";
|
||||
|
||||
const CardTitle = React.forwardRef<
|
||||
HTMLParagraphElement,
|
||||
React.HTMLAttributes<HTMLHeadingElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<h3
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"text-2xl font-semibold leading-none tracking-tight",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
CardTitle.displayName = "CardTitle";
|
||||
|
||||
const CardDescription = React.forwardRef<
|
||||
HTMLParagraphElement,
|
||||
React.HTMLAttributes<HTMLParagraphElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<p
|
||||
ref={ref}
|
||||
className={cn("text-sm text-muted-foreground", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
CardDescription.displayName = "CardDescription";
|
||||
|
||||
const CardContent = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div ref={ref} className={cn("p-6 pt-0", className)} {...props} />
|
||||
));
|
||||
CardContent.displayName = "CardContent";
|
||||
|
||||
const CardFooter = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.HTMLAttributes<HTMLDivElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn("flex items-center p-6 pt-0", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
CardFooter.displayName = "CardFooter";
|
||||
|
||||
export {
|
||||
Card,
|
||||
CardHeader,
|
||||
CardFooter,
|
||||
CardTitle,
|
||||
CardDescription,
|
||||
CardContent,
|
||||
};
|
||||
365
web/src/components/ui/chart.tsx
Normal file
365
web/src/components/ui/chart.tsx
Normal file
@@ -0,0 +1,365 @@
|
||||
"use client";
|
||||
|
||||
import * as React from "react";
|
||||
import * as RechartsPrimitive from "recharts";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
// Format: { THEME_NAME: CSS_SELECTOR }
|
||||
const THEMES = { light: "", dark: ".dark" } as const;
|
||||
|
||||
export type ChartConfig = {
|
||||
[k in string]: {
|
||||
label?: React.ReactNode;
|
||||
icon?: React.ComponentType;
|
||||
} & (
|
||||
| { color?: string; theme?: never }
|
||||
| { color?: never; theme: Record<keyof typeof THEMES, string> }
|
||||
);
|
||||
};
|
||||
|
||||
type ChartContextProps = {
|
||||
config: ChartConfig;
|
||||
};
|
||||
|
||||
const ChartContext = React.createContext<ChartContextProps | null>(null);
|
||||
|
||||
function useChart() {
|
||||
const context = React.useContext(ChartContext);
|
||||
|
||||
if (!context) {
|
||||
throw new Error("useChart must be used within a <ChartContainer />");
|
||||
}
|
||||
|
||||
return context;
|
||||
}
|
||||
|
||||
const ChartContainer = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.ComponentProps<"div"> & {
|
||||
config: ChartConfig;
|
||||
children: React.ComponentProps<
|
||||
typeof RechartsPrimitive.ResponsiveContainer
|
||||
>["children"];
|
||||
}
|
||||
>(({ id, className, children, config, ...props }, ref) => {
|
||||
const uniqueId = React.useId();
|
||||
const chartId = `chart-${id || uniqueId.replace(/:/g, "")}`;
|
||||
|
||||
return (
|
||||
<ChartContext.Provider value={{ config }}>
|
||||
<div
|
||||
data-chart={chartId}
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"flex aspect-video justify-center text-xs [&_.recharts-cartesian-axis-tick_text]:fill-muted-foreground [&_.recharts-cartesian-grid_line[stroke='#ccc']]:stroke-border/50 [&_.recharts-curve.recharts-tooltip-cursor]:stroke-border [&_.recharts-dot[stroke='#fff']]:stroke-transparent [&_.recharts-layer]:outline-none [&_.recharts-polar-grid_[stroke='#ccc']]:stroke-border [&_.recharts-radial-bar-background-sector]:fill-muted [&_.recharts-rectangle.recharts-tooltip-cursor]:fill-muted [&_.recharts-reference-line_[stroke='#ccc']]:stroke-border [&_.recharts-sector[stroke='#fff']]:stroke-transparent [&_.recharts-sector]:outline-none [&_.recharts-surface]:outline-none",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
>
|
||||
<ChartStyle id={chartId} config={config} />
|
||||
<RechartsPrimitive.ResponsiveContainer>
|
||||
{children}
|
||||
</RechartsPrimitive.ResponsiveContainer>
|
||||
</div>
|
||||
</ChartContext.Provider>
|
||||
);
|
||||
});
|
||||
ChartContainer.displayName = "Chart";
|
||||
|
||||
const ChartStyle = ({ id, config }: { id: string; config: ChartConfig }) => {
|
||||
const colorConfig = Object.entries(config).filter(
|
||||
([_, config]) => config.theme || config.color
|
||||
);
|
||||
|
||||
if (!colorConfig.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<style
|
||||
dangerouslySetInnerHTML={{
|
||||
__html: Object.entries(THEMES)
|
||||
.map(
|
||||
([theme, prefix]) => `
|
||||
${prefix} [data-chart=${id}] {
|
||||
${colorConfig
|
||||
.map(([key, itemConfig]) => {
|
||||
const color =
|
||||
itemConfig.theme?.[theme as keyof typeof itemConfig.theme] ||
|
||||
itemConfig.color;
|
||||
return color ? ` --color-${key}: ${color};` : null;
|
||||
})
|
||||
.join("\n")}
|
||||
}
|
||||
`
|
||||
)
|
||||
.join("\n"),
|
||||
}}
|
||||
/>
|
||||
);
|
||||
};
|
||||
|
||||
const ChartTooltip = RechartsPrimitive.Tooltip;
|
||||
|
||||
const ChartTooltipContent = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.ComponentProps<typeof RechartsPrimitive.Tooltip> &
|
||||
React.ComponentProps<"div"> & {
|
||||
hideLabel?: boolean;
|
||||
hideIndicator?: boolean;
|
||||
indicator?: "line" | "dot" | "dashed";
|
||||
nameKey?: string;
|
||||
labelKey?: string;
|
||||
}
|
||||
>(
|
||||
(
|
||||
{
|
||||
active,
|
||||
payload,
|
||||
className,
|
||||
indicator = "dot",
|
||||
hideLabel = false,
|
||||
hideIndicator = false,
|
||||
label,
|
||||
labelFormatter,
|
||||
labelClassName,
|
||||
formatter,
|
||||
color,
|
||||
nameKey,
|
||||
labelKey,
|
||||
},
|
||||
ref
|
||||
) => {
|
||||
const { config } = useChart();
|
||||
|
||||
const tooltipLabel = React.useMemo(() => {
|
||||
if (hideLabel || !payload?.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const [item] = payload;
|
||||
const key = `${labelKey || item.dataKey || item.name || "value"}`;
|
||||
const itemConfig = getPayloadConfigFromPayload(config, item, key);
|
||||
const value =
|
||||
!labelKey && typeof label === "string"
|
||||
? config[label as keyof typeof config]?.label || label
|
||||
: itemConfig?.label;
|
||||
|
||||
if (labelFormatter) {
|
||||
return (
|
||||
<div className={cn("font-medium", labelClassName)}>
|
||||
{labelFormatter(value, payload)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!value) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return <div className={cn("font-medium", labelClassName)}>{value}</div>;
|
||||
}, [
|
||||
label,
|
||||
labelFormatter,
|
||||
payload,
|
||||
hideLabel,
|
||||
labelClassName,
|
||||
config,
|
||||
labelKey,
|
||||
]);
|
||||
|
||||
if (!active || !payload?.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const nestLabel = payload.length === 1 && indicator !== "dot";
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"grid min-w-[8rem] items-start gap-1.5 rounded-lg border border-border/50 bg-background px-2.5 py-1.5 text-xs shadow-xl",
|
||||
className
|
||||
)}
|
||||
>
|
||||
{!nestLabel ? tooltipLabel : null}
|
||||
<div className="grid gap-1.5">
|
||||
{payload.map((item, index) => {
|
||||
const key = `${nameKey || item.name || item.dataKey || "value"}`;
|
||||
const itemConfig = getPayloadConfigFromPayload(config, item, key);
|
||||
const indicatorColor = color || item.payload.fill || item.color;
|
||||
|
||||
return (
|
||||
<div
|
||||
key={item.dataKey}
|
||||
className={cn(
|
||||
"flex w-full flex-wrap items-stretch gap-2 [&>svg]:h-2.5 [&>svg]:w-2.5 [&>svg]:text-muted-foreground",
|
||||
indicator === "dot" && "items-center"
|
||||
)}
|
||||
>
|
||||
{formatter && item?.value !== undefined && item.name ? (
|
||||
formatter(item.value, item.name, item, index, item.payload)
|
||||
) : (
|
||||
<>
|
||||
{itemConfig?.icon ? (
|
||||
<itemConfig.icon />
|
||||
) : (
|
||||
!hideIndicator && (
|
||||
<div
|
||||
className={cn(
|
||||
"shrink-0 rounded-[2px] border-[--color-border] bg-[--color-bg]",
|
||||
{
|
||||
"h-2.5 w-2.5": indicator === "dot",
|
||||
"w-1": indicator === "line",
|
||||
"w-0 border-[1.5px] border-dashed bg-transparent":
|
||||
indicator === "dashed",
|
||||
"my-0.5": nestLabel && indicator === "dashed",
|
||||
}
|
||||
)}
|
||||
style={
|
||||
{
|
||||
"--color-bg": indicatorColor,
|
||||
"--color-border": indicatorColor,
|
||||
} as React.CSSProperties
|
||||
}
|
||||
/>
|
||||
)
|
||||
)}
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-1 justify-between leading-none",
|
||||
nestLabel ? "items-end" : "items-center"
|
||||
)}
|
||||
>
|
||||
<div className="grid gap-1.5">
|
||||
{nestLabel ? tooltipLabel : null}
|
||||
<span className="text-muted-foreground">
|
||||
{itemConfig?.label || item.name}
|
||||
</span>
|
||||
</div>
|
||||
{item.value && (
|
||||
<span className="font-mono font-medium tabular-nums text-foreground">
|
||||
{item.value.toLocaleString()}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
ChartTooltipContent.displayName = "ChartTooltip";
|
||||
|
||||
const ChartLegend = RechartsPrimitive.Legend;
|
||||
|
||||
const ChartLegendContent = React.forwardRef<
|
||||
HTMLDivElement,
|
||||
React.ComponentProps<"div"> &
|
||||
Pick<RechartsPrimitive.LegendProps, "payload" | "verticalAlign"> & {
|
||||
hideIcon?: boolean;
|
||||
nameKey?: string;
|
||||
}
|
||||
>(
|
||||
(
|
||||
{ className, hideIcon = false, payload, verticalAlign = "bottom", nameKey },
|
||||
ref
|
||||
) => {
|
||||
const { config } = useChart();
|
||||
|
||||
if (!payload?.length) {
|
||||
return null;
|
||||
}
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"flex items-center justify-center gap-4",
|
||||
verticalAlign === "top" ? "pb-3" : "pt-3",
|
||||
className
|
||||
)}
|
||||
>
|
||||
{payload.map((item) => {
|
||||
const key = `${nameKey || item.dataKey || "value"}`;
|
||||
const itemConfig = getPayloadConfigFromPayload(config, item, key);
|
||||
|
||||
return (
|
||||
<div
|
||||
key={item.value}
|
||||
className={cn(
|
||||
"flex items-center gap-1.5 [&>svg]:h-3 [&>svg]:w-3 [&>svg]:text-muted-foreground"
|
||||
)}
|
||||
>
|
||||
{itemConfig?.icon && !hideIcon ? (
|
||||
<itemConfig.icon />
|
||||
) : (
|
||||
<div
|
||||
className="h-2 w-2 shrink-0 rounded-[2px]"
|
||||
style={{
|
||||
backgroundColor: item.color,
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
{itemConfig?.label}
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
);
|
||||
ChartLegendContent.displayName = "ChartLegend";
|
||||
|
||||
// Helper to extract item config from a payload.
|
||||
function getPayloadConfigFromPayload(
|
||||
config: ChartConfig,
|
||||
payload: unknown,
|
||||
key: string
|
||||
) {
|
||||
if (typeof payload !== "object" || payload === null) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
const payloadPayload =
|
||||
"payload" in payload &&
|
||||
typeof payload.payload === "object" &&
|
||||
payload.payload !== null
|
||||
? payload.payload
|
||||
: undefined;
|
||||
|
||||
let configLabelKey: string = key;
|
||||
|
||||
if (
|
||||
key in payload &&
|
||||
typeof payload[key as keyof typeof payload] === "string"
|
||||
) {
|
||||
configLabelKey = payload[key as keyof typeof payload] as string;
|
||||
} else if (
|
||||
payloadPayload &&
|
||||
key in payloadPayload &&
|
||||
typeof payloadPayload[key as keyof typeof payloadPayload] === "string"
|
||||
) {
|
||||
configLabelKey = payloadPayload[
|
||||
key as keyof typeof payloadPayload
|
||||
] as string;
|
||||
}
|
||||
|
||||
return configLabelKey in config
|
||||
? config[configLabelKey]
|
||||
: config[key as keyof typeof config];
|
||||
}
|
||||
|
||||
export {
|
||||
ChartContainer,
|
||||
ChartTooltip,
|
||||
ChartTooltipContent,
|
||||
ChartLegend,
|
||||
ChartLegendContent,
|
||||
ChartStyle,
|
||||
};
|
||||
25
web/src/components/ui/input.tsx
Normal file
25
web/src/components/ui/input.tsx
Normal file
@@ -0,0 +1,25 @@
|
||||
import * as React from "react";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface InputProps
|
||||
extends React.InputHTMLAttributes<HTMLInputElement> {}
|
||||
|
||||
const Input = React.forwardRef<HTMLInputElement, InputProps>(
|
||||
({ className, type, ...props }, ref) => {
|
||||
return (
|
||||
<input
|
||||
type={type}
|
||||
className={cn(
|
||||
"flex h-10 w-full rounded-md border border-input bg-background px-3 py-2 text-sm ring-offset-background file:border-0 file:bg-transparent file:text-sm file:font-medium placeholder:text-muted-foreground focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-ring focus-visible:ring-offset-2 disabled:cursor-not-allowed disabled:opacity-50",
|
||||
className
|
||||
)}
|
||||
ref={ref}
|
||||
{...props}
|
||||
/>
|
||||
);
|
||||
}
|
||||
);
|
||||
Input.displayName = "Input";
|
||||
|
||||
export { Input };
|
||||
117
web/src/components/ui/table.tsx
Normal file
117
web/src/components/ui/table.tsx
Normal file
@@ -0,0 +1,117 @@
|
||||
import * as React from "react";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const Table = React.forwardRef<
|
||||
HTMLTableElement,
|
||||
React.HTMLAttributes<HTMLTableElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<div className="relative w-full overflow-auto">
|
||||
<table
|
||||
ref={ref}
|
||||
className={cn("w-full caption-bottom text-sm", className)}
|
||||
{...props}
|
||||
/>
|
||||
</div>
|
||||
));
|
||||
Table.displayName = "Table";
|
||||
|
||||
const TableHeader = React.forwardRef<
|
||||
HTMLTableSectionElement,
|
||||
React.HTMLAttributes<HTMLTableSectionElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<thead ref={ref} className={cn("[&_tr]:border-b", className)} {...props} />
|
||||
));
|
||||
TableHeader.displayName = "TableHeader";
|
||||
|
||||
const TableBody = React.forwardRef<
|
||||
HTMLTableSectionElement,
|
||||
React.HTMLAttributes<HTMLTableSectionElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<tbody
|
||||
ref={ref}
|
||||
className={cn("[&_tr:last-child]:border-0", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
TableBody.displayName = "TableBody";
|
||||
|
||||
const TableFooter = React.forwardRef<
|
||||
HTMLTableSectionElement,
|
||||
React.HTMLAttributes<HTMLTableSectionElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<tfoot
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"border-t bg-muted/50 font-medium [&>tr]:last:border-b-0",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
TableFooter.displayName = "TableFooter";
|
||||
|
||||
const TableRow = React.forwardRef<
|
||||
HTMLTableRowElement,
|
||||
React.HTMLAttributes<HTMLTableRowElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<tr
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"border-b transition-colors hover:bg-muted/50 data-[state=selected]:bg-muted",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
TableRow.displayName = "TableRow";
|
||||
|
||||
const TableHead = React.forwardRef<
|
||||
HTMLTableCellElement,
|
||||
React.ThHTMLAttributes<HTMLTableCellElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<th
|
||||
ref={ref}
|
||||
className={cn(
|
||||
"h-12 px-4 text-left align-middle font-medium text-muted-foreground [&:has([role=checkbox])]:pr-0",
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
TableHead.displayName = "TableHead";
|
||||
|
||||
const TableCell = React.forwardRef<
|
||||
HTMLTableCellElement,
|
||||
React.TdHTMLAttributes<HTMLTableCellElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<td
|
||||
ref={ref}
|
||||
className={cn("p-4 align-middle [&:has([role=checkbox])]:pr-0", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
TableCell.displayName = "TableCell";
|
||||
|
||||
const TableCaption = React.forwardRef<
|
||||
HTMLTableCaptionElement,
|
||||
React.HTMLAttributes<HTMLTableCaptionElement>
|
||||
>(({ className, ...props }, ref) => (
|
||||
<caption
|
||||
ref={ref}
|
||||
className={cn("mt-4 text-sm text-muted-foreground", className)}
|
||||
{...props}
|
||||
/>
|
||||
));
|
||||
TableCaption.displayName = "TableCaption";
|
||||
|
||||
export {
|
||||
Table,
|
||||
TableHeader,
|
||||
TableBody,
|
||||
TableFooter,
|
||||
TableHead,
|
||||
TableRow,
|
||||
TableCell,
|
||||
TableCaption,
|
||||
};
|
||||
@@ -30,6 +30,9 @@ export const SIDEBAR_WIDTH = `w-[350px]`;
|
||||
export const LOGOUT_DISABLED =
|
||||
process.env.NEXT_PUBLIC_DISABLE_LOGOUT?.toLowerCase() === "true";
|
||||
|
||||
export const DISABLED_CSV_DISPLAY =
|
||||
process.env.NEXT_PUBLIC_DISABLE_CSV_DISPLAY?.toLowerCase() === "true";
|
||||
|
||||
export const NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN =
|
||||
process.env.NEXT_PUBLIC_DEFAULT_SIDEBAR_OPEN?.toLowerCase() === "true";
|
||||
|
||||
|
||||
@@ -22,6 +22,8 @@ export interface AnswerPiecePacket {
|
||||
export enum StreamStopReason {
|
||||
CONTEXT_LENGTH = "CONTEXT_LENGTH",
|
||||
CANCELLED = "CANCELLED",
|
||||
FINISHED = "FINISHED",
|
||||
NEW_RESPONSE = "NEW_RESPONSE",
|
||||
}
|
||||
|
||||
export interface StreamStopInfo {
|
||||
|
||||
6
web/src/lib/utils.ts
Normal file
6
web/src/lib/utils.ts
Normal file
@@ -0,0 +1,6 @@
|
||||
import { type ClassValue, clsx } from "clsx";
|
||||
import { twMerge } from "tailwind-merge";
|
||||
|
||||
export function cn(...inputs: ClassValue[]) {
|
||||
return twMerge(clsx(inputs));
|
||||
}
|
||||
80
web/tailwind.config.ts
Normal file
80
web/tailwind.config.ts
Normal file
@@ -0,0 +1,80 @@
|
||||
import type { Config } from "tailwindcss";
|
||||
|
||||
const config = {
|
||||
darkMode: ["class"],
|
||||
content: [
|
||||
"./pages/**/*.{ts,tsx}",
|
||||
"./components/**/*.{ts,tsx}",
|
||||
"./app/**/*.{ts,tsx}",
|
||||
"./src/**/*.{ts,tsx}",
|
||||
],
|
||||
prefix: "",
|
||||
theme: {
|
||||
container: {
|
||||
center: true,
|
||||
padding: "2rem",
|
||||
screens: {
|
||||
"2xl": "1400px",
|
||||
},
|
||||
},
|
||||
extend: {
|
||||
colors: {
|
||||
border: "hsl(var(--border))",
|
||||
input: "hsl(var(--input))",
|
||||
ring: "hsl(var(--ring))",
|
||||
background: "hsl(var(--background))",
|
||||
foreground: "hsl(var(--foreground))",
|
||||
primary: {
|
||||
DEFAULT: "hsl(var(--primary))",
|
||||
foreground: "hsl(var(--primary-foreground))",
|
||||
},
|
||||
secondary: {
|
||||
DEFAULT: "hsl(var(--secondary))",
|
||||
foreground: "hsl(var(--secondary-foreground))",
|
||||
},
|
||||
destructive: {
|
||||
DEFAULT: "hsl(var(--destructive))",
|
||||
foreground: "hsl(var(--destructive-foreground))",
|
||||
},
|
||||
muted: {
|
||||
DEFAULT: "hsl(var(--muted))",
|
||||
foreground: "hsl(var(--muted-foreground))",
|
||||
},
|
||||
accent: {
|
||||
DEFAULT: "hsl(var(--accent))",
|
||||
foreground: "hsl(var(--accent-foreground))",
|
||||
},
|
||||
popover: {
|
||||
DEFAULT: "hsl(var(--popover))",
|
||||
foreground: "hsl(var(--popover-foreground))",
|
||||
},
|
||||
card: {
|
||||
DEFAULT: "hsl(var(--card))",
|
||||
foreground: "hsl(var(--card-foreground))",
|
||||
},
|
||||
},
|
||||
borderRadius: {
|
||||
lg: "var(--radius)",
|
||||
md: "calc(var(--radius) - 2px)",
|
||||
sm: "calc(var(--radius) - 4px)",
|
||||
},
|
||||
keyframes: {
|
||||
"accordion-down": {
|
||||
from: { height: "0" },
|
||||
to: { height: "var(--radix-accordion-content-height)" },
|
||||
},
|
||||
"accordion-up": {
|
||||
from: { height: "var(--radix-accordion-content-height)" },
|
||||
to: { height: "0" },
|
||||
},
|
||||
},
|
||||
animation: {
|
||||
"accordion-down": "accordion-down 0.2s ease-out",
|
||||
"accordion-up": "accordion-up 0.2s ease-out",
|
||||
},
|
||||
},
|
||||
},
|
||||
plugins: [require("tailwindcss-animate")],
|
||||
} satisfies Config;
|
||||
|
||||
export default config;
|
||||
Reference in New Issue
Block a user