mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-02-20 17:25:44 +00:00

Compare commits: bot_nit...additional (1 commit)

| Author | SHA1 | Date |
|---|---|---|
|  | 6f9ed98a40 |  |

@@ -1,40 +0,0 @@
"""non-nullbale slack bot id in channel config

Revision ID: f7a894b06d02
Revises: 9f696734098f
Create Date: 2024-12-06 12:55:42.845723

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "f7a894b06d02"
down_revision = "9f696734098f"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Delete all rows with null slack_bot_id
    op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL")

    # Make slack_bot_id non-nullable
    op.alter_column(
        "slack_channel_config",
        "slack_bot_id",
        existing_type=sa.Integer(),
        nullable=False,
    )


def downgrade() -> None:
    # Make slack_bot_id nullable again
    op.alter_column(
        "slack_channel_config",
        "slack_bot_id",
        existing_type=sa.Integer(),
        nullable=True,
    )

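Since this compare drops the revision file entirely, here is a minimal sketch of how the same revision would be applied or rolled back through Alembic's Python command API (assuming a standard `alembic.ini` with `sqlalchemy.url` configured; illustrative only, not part of the diff):

```python
# Illustrative only -- assumes a standard alembic.ini in the working directory.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")

# Upgrade: deletes rows with a NULL slack_bot_id, then adds the NOT NULL constraint.
command.upgrade(cfg, "f7a894b06d02")

# Downgrade: makes slack_bot_id nullable again.
command.downgrade(cfg, "9f696734098f")
```
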
@@ -58,6 +58,7 @@ from danswer.auth.schemas import UserRole
from danswer.auth.schemas import UserUpdate
from danswer.configs.app_configs import AUTH_TYPE
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_VERIFICATION
from danswer.configs.app_configs import EMAIL_FROM
from danswer.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from danswer.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS

@@ -131,12 +132,11 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:


def user_needs_to_be_verified() -> bool:
    if AUTH_TYPE == AuthType.BASIC:
        return REQUIRE_EMAIL_VERIFICATION

    # For other auth types, if the user is authenticated it's assumed that
    # the user is already verified via the external IDP
    return False
    # all other auth types besides basic should require users to be
    # verified
    return not DISABLE_VERIFICATION and (
        AUTH_TYPE != AuthType.BASIC or REQUIRE_EMAIL_VERIFICATION
    )


def verify_email_is_invited(email: str) -> None:

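For reference, the combined check in the second version reduces to the following behaviour (a standalone sketch that mirrors the boolean expression above; it is not imported from danswer):

```python
# Illustrative only -- mirrors the expression above.
def needs_verification(
    disable_verification: bool, is_basic_auth: bool, require_email_verification: bool
) -> bool:
    return not disable_verification and (not is_basic_auth or require_email_verification)


assert needs_verification(True, False, False) is False   # verification disabled globally
assert needs_verification(False, True, False) is False   # basic auth without the env flag
assert needs_verification(False, True, True) is True     # basic auth with REQUIRE_EMAIL_VERIFICATION
assert needs_verification(False, False, True) is True    # non-basic auth types always verify
```
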
@@ -219,7 +219,7 @@ def connector_permission_sync_generator_task(

    r = get_redis_client(tenant_id=tenant_id)

    lock: RedisLock = r.lock(
    lock = r.lock(
        DanswerRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX
        + f"_{redis_connector.id}",
        timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT,

@@ -26,7 +26,7 @@ from danswer.db.models import Prompt
from danswer.db.models import Tool
from danswer.db.models import User
from danswer.db.persona import get_prompts_by_ids
from danswer.llm.models import PreviousMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.natural_language_processing.utils import BaseTokenizer
from danswer.server.query_and_chat.models import CreateChatMessageRequest
from danswer.tools.tool_implementations.custom.custom_tool import (

@@ -1,14 +1,10 @@
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator

from danswer.configs.constants import DocumentSource
from danswer.configs.constants import MessageType
@@ -16,15 +12,8 @@ from danswer.context.search.enums import QueryFlow
from danswer.context.search.enums import RecencyBiasSetting
from danswer.context.search.enums import SearchType
from danswer.context.search.models import RetrievalDocs
from danswer.llm.override_models import PromptOverride
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool_implementations.custom.base_tool_types import ToolResultType

if TYPE_CHECKING:
    from danswer.db.models import Prompt


class LlmDoc(BaseModel):
    """This contains the minimal set information for the LLM portion including citations"""
@@ -221,109 +210,3 @@ AnswerQuestionStreamReturn = Iterator[AnswerQuestionPossibleReturn]
class LLMMetricsContainer(BaseModel):
    prompt_tokens: int
    response_tokens: int


StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]


class DocumentPruningConfig(BaseModel):
    max_chunks: int | None = None
    max_window_percentage: float | None = None
    max_tokens: int | None = None
    # different pruning behavior is expected when the
    # user manually selects documents they want to chat with
    # e.g. we don't want to truncate each document to be no more
    # than one chunk long
    is_manually_selected_docs: bool = False
    # If user specifies to include additional context Chunks for each match, then different pruning
    # is used. As many Sections as possible are included, and the last Section is truncated
    # If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
    # Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
    use_sections: bool = True
    # If using tools, then we need to consider the tool length
    tool_num_tokens: int = 0
    # If using a tool message to represent the docs, then we have to JSON serialize
    # the document content, which adds to the token count.
    using_tool_message: bool = False


class ContextualPruningConfig(DocumentPruningConfig):
    num_chunk_multiple: int

    @classmethod
    def from_doc_pruning_config(
        cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
    ) -> "ContextualPruningConfig":
        return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())


class CitationConfig(BaseModel):
    all_docs_useful: bool = False


class QuotesConfig(BaseModel):
    pass


class AnswerStyleConfig(BaseModel):
    citation_config: CitationConfig | None = None
    quotes_config: QuotesConfig | None = None
    document_pruning_config: DocumentPruningConfig = Field(
        default_factory=DocumentPruningConfig
    )
    # forces the LLM to return a structured response, see
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    # right now, only used by the simple chat API
    structured_response_format: dict | None = None

    @model_validator(mode="after")
    def check_quotes_and_citation(self) -> "AnswerStyleConfig":
        if self.citation_config is None and self.quotes_config is None:
            raise ValueError(
                "One of `citation_config` or `quotes_config` must be provided"
            )

        if self.citation_config is not None and self.quotes_config is not None:
            raise ValueError(
                "Only one of `citation_config` or `quotes_config` must be provided"
            )

        return self


class PromptConfig(BaseModel):
    """Final representation of the Prompt configuration passed
    into the `Answer` object."""

    system_prompt: str
    task_prompt: str
    datetime_aware: bool
    include_citations: bool

    @classmethod
    def from_model(
        cls, model: "Prompt", prompt_override: PromptOverride | None = None
    ) -> "PromptConfig":
        override_system_prompt = (
            prompt_override.system_prompt if prompt_override else None
        )
        override_task_prompt = prompt_override.task_prompt if prompt_override else None

        return cls(
            system_prompt=override_system_prompt or model.system_prompt,
            task_prompt=override_task_prompt or model.task_prompt,
            datetime_aware=model.datetime_aware,
            include_citations=model.include_citations,
        )

    model_config = ConfigDict(frozen=True)


ResponsePart = (
    DanswerAnswerPiece
    | CitationInfo
    | ToolCallKickoff
    | ToolResponse
    | ToolCallFinalResult
    | StreamStopInfo
)

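For reference, a small sketch of how the `AnswerStyleConfig` validator shown above behaves at construction time (illustrative; the import path for these classes depends on which side of this compare is checked out):

```python
# Illustrative only.
from pydantic import ValidationError

AnswerStyleConfig(citation_config=CitationConfig())  # ok: exactly one config supplied

try:
    AnswerStyleConfig()  # neither config -> ValidationError wrapping the ValueError above
except ValidationError as err:
    print(err)

try:
    AnswerStyleConfig(
        citation_config=CitationConfig(), quotes_config=QuotesConfig()
    )  # both configs -> ValidationError as well
except ValidationError as err:
    print(err)
```
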
@@ -6,24 +6,19 @@ from typing import cast

from sqlalchemy.orm import Session

from danswer.chat.answer import Answer
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import create_temporary_persona
from danswer.chat.models import AllCitations
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import ChatDanswerBotResponse
from danswer.chat.models import CitationConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import CustomToolResponse
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DocumentPruningConfig
from danswer.chat.models import FileChatDisplay
from danswer.chat.models import FinalUsedContextDocsResponse
from danswer.chat.models import LLMRelevanceFilterResponse
from danswer.chat.models import MessageResponseIDInfo
from danswer.chat.models import MessageSpecificCitations
from danswer.chat.models import PromptConfig
from danswer.chat.models import QADocsResponse
from danswer.chat.models import StreamingError
from danswer.chat.models import StreamStopInfo
@@ -62,11 +57,16 @@ from danswer.document_index.factory import get_default_document_index
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
from danswer.file_store.utils import load_all_chat_files
from danswer.file_store.utils import save_files
from danswer.file_store.utils import save_files_from_urls
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import litellm_exception_to_error_msg
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.server.query_and_chat.models import ChatMessageDetail
@@ -119,7 +119,6 @@ from danswer.utils.logger import setup_logger
from danswer.utils.long_term_log import LongTermLogger
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


logger = setup_logger()
@@ -303,7 +302,6 @@ def stream_chat_message_objects(
    3. [always] A set of streamed LLM tokens or an error anywhere along the line if something fails
    4. [always] Details on the final AI response message that is created
    """
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
    use_existing_user_message = new_msg_req.use_existing_user_message
    existing_assistant_message_id = new_msg_req.existing_assistant_message_id

@@ -680,8 +678,7 @@ def stream_chat_message_objects(

        reference_db_search_docs = None
        qa_docs_response = None
        # any files to associate with the AI message e.g. dall-e generated images
        ai_message_files = []
        ai_message_files = None  # any files to associate with the AI message e.g. dall-e generated images
        dropped_indices = None
        tool_result = None

@@ -736,14 +733,8 @@ def stream_chat_message_objects(
                        list[ImageGenerationResponse], packet.response
                    )

                    file_ids = save_files(
                        urls=[img.url for img in img_generation_response if img.url],
                        base64_files=[
                            img.image_data
                            for img in img_generation_response
                            if img.image_data
                        ],
                        tenant_id=tenant_id,
                    file_ids = save_files_from_urls(
                        [img.url for img in img_generation_response]
                    )
                    ai_message_files = [
                        FileDescriptor(id=str(file_id), type=ChatFileType.IMAGE)
@@ -769,19 +760,15 @@ def stream_chat_message_objects(
                        or custom_tool_response.response_type == "csv"
                    ):
                        file_ids = custom_tool_response.tool_result.file_ids
                        ai_message_files.extend(
                            [
                                FileDescriptor(
                                    id=str(file_id),
                                    type=(
                                        ChatFileType.IMAGE
                                        if custom_tool_response.response_type == "image"
                                        else ChatFileType.CSV
                                    ),
                                )
                                for file_id in file_ids
                            ]
                        )
                        ai_message_files = [
                            FileDescriptor(
                                id=str(file_id),
                                type=ChatFileType.IMAGE
                                if custom_tool_response.response_type == "image"
                                else ChatFileType.CSV,
                            )
                            for file_id in file_ids
                        ]
                        yield FileChatDisplay(
                            file_ids=[str(file_id) for file_id in file_ids]
                        )
@@ -831,8 +818,7 @@ def stream_chat_message_objects(
            citations_list=answer.citations,
            db_docs=reference_db_search_docs,
        )
        if not answer.is_cancelled():
            yield AllCitations(citations=answer.citations)
        yield AllCitations(citations=answer.citations)

        # Saving Gen AI answer and responding with message info
        tool_name_to_tool_id: dict[str, int] = {}

@@ -1,62 +0,0 @@
from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage

from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT


def build_dummy_prompt(
    system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
    if retrieval_disabled:
        return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
            user_query="<USER_QUERY>",
            system_prompt=system_prompt,
            task_prompt=task_prompt,
        ).strip()

    return PARAMATERIZED_PROMPT.format(
        context_docs_str="<CONTEXT_DOCS>",
        user_query="<USER_QUERY>",
        system_prompt=system_prompt,
        task_prompt=task_prompt,
    ).strip()


def translate_danswer_msg_to_langchain(
    msg: ChatMessage | PreviousMessage,
) -> BaseMessage:
    files: list[InMemoryChatFile] = []

    # If the message is a `ChatMessage`, it doesn't have the downloaded files
    # attached. Just ignore them for now.
    if not isinstance(msg, ChatMessage):
        files = msg.files
    content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)

    if msg.message_type == MessageType.SYSTEM:
        raise ValueError("System messages are not currently part of history")
    if msg.message_type == MessageType.ASSISTANT:
        return AIMessage(content=content)
    if msg.message_type == MessageType.USER:
        return HumanMessage(content=content)

    raise ValueError(f"New message type {msg.message_type} not handled")


def translate_history_to_basemessages(
    history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
    history_basemessages = [
        translate_danswer_msg_to_langchain(msg)
        for msg in history
        if msg.token_count != 0
    ]
    history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
    return history_basemessages, history_token_counts

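A minimal sketch of calling `build_dummy_prompt` above (illustrative; the templates come from danswer.prompts.direct_qa_prompts):

```python
# Illustrative only.
preview = build_dummy_prompt(
    system_prompt="You are a helpful assistant.",
    task_prompt="Answer concisely.",
    retrieval_disabled=True,
)
# `preview` is PARAMATERIZED_PROMPT_WITHOUT_CONTEXT rendered with the literal
# "<USER_QUERY>" placeholder, suitable for showing a prompt preview in the UI.
print(preview)
```
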
@@ -43,6 +43,9 @@ WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000"
AUTH_TYPE = AuthType((os.environ.get("AUTH_TYPE") or AuthType.DISABLED.value).lower())
DISABLE_AUTH = AUTH_TYPE == AuthType.DISABLED

# Necessary for cloud integration tests
DISABLE_VERIFICATION = os.environ.get("DISABLE_VERIFICATION", "").lower() == "true"

# Encryption key secret is used to encrypt connector credentials, api keys, and other sensitive
# information. This provides an extra layer of security on top of Postgres access controls
# and is available in Danswer EE
@@ -81,14 +84,7 @@ OAUTH_CLIENT_SECRET = (
    or ""
)

# for future OAuth connector support
# OAUTH_CONFLUENCE_CLIENT_ID = os.environ.get("OAUTH_CONFLUENCE_CLIENT_ID", "")
# OAUTH_CONFLUENCE_CLIENT_SECRET = os.environ.get("OAUTH_CONFLUENCE_CLIENT_SECRET", "")
# OAUTH_JIRA_CLIENT_ID = os.environ.get("OAUTH_JIRA_CLIENT_ID", "")
# OAUTH_JIRA_CLIENT_SECRET = os.environ.get("OAUTH_JIRA_CLIENT_SECRET", "")

USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "")

# for basic auth
REQUIRE_EMAIL_VERIFICATION = (
    os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true"
@@ -122,8 +118,6 @@ VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
# the number of times to try and connect to vespa on startup before giving up
VESPA_NUM_ATTEMPTS_ON_STARTUP = int(os.environ.get("NUM_RETRIES_ON_STARTUP") or 10)

VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")

@@ -2,8 +2,6 @@ import json
import os


IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get("IMAGE_GENERATION_OUTPUT_FORMAT", "url")

# if specified, will pass through request headers to the call to API calls made by custom tools
CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None
_CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get(

@@ -15,7 +15,6 @@ from danswer.connectors.confluence.utils import attachment_to_content
from danswer.connectors.confluence.utils import build_confluence_document_id
from danswer.connectors.confluence.utils import datetime_from_string
from danswer.connectors.confluence.utils import extract_text_from_confluence_html
from danswer.connectors.confluence.utils import validate_attachment_filetype
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import GenerateSlimDocumentOutput
from danswer.connectors.interfaces import LoadConnector
@@ -277,11 +276,9 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
        ):
            # If the page has restrictions, add them to the perm_sync_data
            # These will be used by doc_sync.py to sync permissions
            page_restrictions = page.get("restrictions")
            page_space_key = page.get("space", {}).get("key")
            page_perm_sync_data = {
                "restrictions": page_restrictions or {},
                "space_key": page_space_key,
            perm_sync_data = {
                "restrictions": page.get("restrictions", {}),
                "space_key": page.get("space", {}).get("key"),
            }

            doc_metadata_list.append(
@@ -291,7 +288,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
                        page["_links"]["webui"],
                        self.is_cloud,
                    ),
                    perm_sync_data=page_perm_sync_data,
                    perm_sync_data=perm_sync_data,
                )
            )
            attachment_cql = f"type=attachment and container='{page['id']}'"
@@ -301,21 +298,6 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
                expand=restrictions_expand,
                limit=_SLIM_DOC_BATCH_SIZE,
            ):
                if not validate_attachment_filetype(attachment):
                    continue
                attachment_restrictions = attachment.get("restrictions")
                if not attachment_restrictions:
                    attachment_restrictions = page_restrictions

                attachment_space_key = attachment.get("space", {}).get("key")
                if not attachment_space_key:
                    attachment_space_key = page_space_key

                attachment_perm_sync_data = {
                    "restrictions": attachment_restrictions or {},
                    "space_key": attachment_space_key,
                }

                doc_metadata_list.append(
                    SlimDocument(
                        id=build_confluence_document_id(
@@ -323,7 +305,7 @@ class ConfluenceConnector(LoadConnector, PollConnector, SlimConnector):
                            attachment["_links"]["webui"],
                            self.is_cloud,
                        ),
                        perm_sync_data=attachment_perm_sync_data,
                        perm_sync_data=perm_sync_data,
                    )
                )
                if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:

@@ -177,23 +177,19 @@ def extract_text_from_confluence_html(
    return format_document_soup(soup)


def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
    return attachment["metadata"]["mediaType"] not in [
def attachment_to_content(
    confluence_client: OnyxConfluence,
    attachment: dict[str, Any],
) -> str | None:
    """If it returns None, assume that we should skip this attachment."""
    if attachment["metadata"]["mediaType"] in [
        "image/jpeg",
        "image/png",
        "image/gif",
        "image/svg+xml",
        "video/mp4",
        "video/quicktime",
    ]


def attachment_to_content(
    confluence_client: OnyxConfluence,
    attachment: dict[str, Any],
) -> str | None:
    """If it returns None, assume that we should skip this attachment."""
    if not validate_attachment_filetype(attachment):
    ]:
        return None

    download_link = confluence_client.url + attachment["_links"]["download"]
@@ -249,7 +245,7 @@ def build_confluence_document_id(
    return f"{base_url}{content_url}"


def _extract_referenced_attachment_names(page_text: str) -> list[str]:
def extract_referenced_attachment_names(page_text: str) -> list[str]:
    """Parse a Confluence html page to generate a list of current
    attachments in use

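A quick sketch of the `validate_attachment_filetype` helper in use, based on the media types listed above (illustrative dictionaries containing only the field the helper reads):

```python
# Illustrative only -- minimal attachment dicts with just the fields the helper reads.
image_attachment = {"metadata": {"mediaType": "image/png"}}
text_attachment = {"metadata": {"mediaType": "text/plain"}}

assert validate_attachment_filetype(image_attachment) is False  # images/videos are skipped
assert validate_attachment_filetype(text_attachment) is True    # everything else is processed
```
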
@@ -5,11 +5,7 @@ from typing import cast

from sqlalchemy.orm import Session

from danswer.chat.models import PromptConfig
from danswer.chat.models import SectionRelevancePiece
from danswer.chat.prune_and_merge import _merge_sections
from danswer.chat.prune_and_merge import ChunkRange
from danswer.chat.prune_and_merge import merge_chunk_intervals
from danswer.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from danswer.context.search.enums import LLMEvaluationType
from danswer.context.search.enums import QueryFlow
@@ -31,6 +27,10 @@ from danswer.db.models import User
from danswer.db.search_settings import get_current_search_settings
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.interfaces import VespaChunkRequest
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prune_and_merge import _merge_sections
from danswer.llm.answering.prune_and_merge import ChunkRange
from danswer.llm.answering.prune_and_merge import merge_chunk_intervals
from danswer.llm.interfaces import LLM
from danswer.secondary_llm_flows.agentic_evaluation import evaluate_inference_section
from danswer.utils.logger import setup_logger

@@ -248,6 +248,7 @@ def create_credential(
    )

    db_session.commit()

    return credential

@@ -1490,9 +1490,7 @@ class SlackChannelConfig(Base):
    __tablename__ = "slack_channel_config"

    id: Mapped[int] = mapped_column(primary_key=True)
    slack_bot_id: Mapped[int] = mapped_column(
        ForeignKey("slack_bot.id"), nullable=False
    )
    slack_bot_id: Mapped[int] = mapped_column(ForeignKey("slack_bot.id"), nullable=True)
    persona_id: Mapped[int | None] = mapped_column(
        ForeignKey("persona.id"), nullable=True
    )

@@ -4,8 +4,6 @@ schema DANSWER_CHUNK_NAME {
    # Not to be confused with the UUID generated for this chunk which is called documentid by default
    field document_id type string {
        indexing: summary | attribute
        attribute: fast-search
        rank: filter
    }
    field chunk_id type int {
        indexing: summary | attribute

@@ -6,7 +6,6 @@ import zipfile
from collections.abc import Callable
from collections.abc import Iterator
from email.parser import Parser as EmailParser
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import Dict
@@ -16,17 +15,13 @@ import chardet
import docx  # type: ignore
import openpyxl  # type: ignore
import pptx  # type: ignore
from docx import Document
from fastapi import UploadFile
from pypdf import PdfReader
from pypdf.errors import PdfStreamError

from danswer.configs.constants import DANSWER_METADATA_FILENAME
from danswer.configs.constants import FileOrigin
from danswer.file_processing.html_utils import parse_html_page_basic
from danswer.file_processing.unstructured import get_unstructured_api_key
from danswer.file_processing.unstructured import unstructured_to_text
from danswer.file_store.file_store import FileStore
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -380,35 +375,3 @@ def extract_file_text(
            ) from e
        logger.warning(f"Failed to process file {file_name or 'Unknown'}: {str(e)}")
        return ""


def convert_docx_to_txt(
    file: UploadFile, file_store: FileStore, file_path: str
) -> None:
    file.file.seek(0)
    docx_content = file.file.read()
    doc = Document(BytesIO(docx_content))

    # Extract text from the document
    full_text = []
    for para in doc.paragraphs:
        full_text.append(para.text)

    # Join the extracted text
    text_content = "\n".join(full_text)

    txt_file_path = docx_to_txt_filename(file_path)
    file_store.save_file(
        file_name=txt_file_path,
        content=BytesIO(text_content.encode("utf-8")),
        display_name=file.filename,
        file_origin=FileOrigin.CONNECTOR,
        file_type="text/plain",
    )


def docx_to_txt_filename(file_path: str) -> str:
    """
    Convert a .docx file path to its corresponding .txt file path.
    """
    return file_path.rsplit(".", 1)[0] + ".txt"

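The filename helper above only swaps the final suffix, which the following sketch illustrates:

```python
# Illustrative only -- follows directly from the rsplit(".", 1) logic above.
assert docx_to_txt_filename("reports/q3_summary.docx") == "reports/q3_summary.txt"
assert docx_to_txt_filename("notes.v2.docx") == "notes.v2.txt"  # only the last extension changes
```
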
@@ -1,6 +1,6 @@
import base64
from collections.abc import Callable
from io import BytesIO
from typing import Any
from typing import cast
from uuid import uuid4

@@ -13,8 +13,8 @@ from danswer.db.models import ChatMessage
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import FileDescriptor
from danswer.file_store.models import InMemoryChatFile
from danswer.utils.b64 import get_image_type
from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


def load_chat_file(

@@ -75,58 +75,11 @@ def save_file_from_url(url: str, tenant_id: str) -> str:
        return unique_id


def save_file_from_base64(base64_string: str, tenant_id: str) -> str:
    with get_session_with_tenant(tenant_id) as db_session:
        unique_id = str(uuid4())
        file_store = get_default_file_store(db_session)
        file_store.save_file(
            file_name=unique_id,
            content=BytesIO(base64.b64decode(base64_string)),
            display_name="GeneratedImage",
            file_origin=FileOrigin.CHAT_IMAGE_GEN,
            file_type=get_image_type(base64_string),
        )
        return unique_id
def save_files_from_urls(urls: list[str]) -> list[str]:
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()


def save_file(
    tenant_id: str,
    url: str | None = None,
    base64_data: str | None = None,
) -> str:
    """Save a file from either a URL or base64 encoded string.

    Args:
        tenant_id: The tenant ID to save the file under
        url: URL to download file from
        base64_data: Base64 encoded file data

    Returns:
        The unique ID of the saved file

    Raises:
        ValueError: If neither url nor base64_data is provided, or if both are provided
    """
    if url is not None and base64_data is not None:
        raise ValueError("Cannot specify both url and base64_data")

    if url is not None:
        return save_file_from_url(url, tenant_id)
    elif base64_data is not None:
        return save_file_from_base64(base64_data, tenant_id)
    else:
        raise ValueError("Must specify either url or base64_data")


def save_files(urls: list[str], base64_files: list[str], tenant_id: str) -> list[str]:
    # NOTE: be explicit about typing so that if we change things, we get notified
    funcs: list[
        tuple[
            Callable[[str, str | None, str | None], str],
            tuple[str, str | None, str | None],
        ]
    ] = [(save_file, (tenant_id, url, None)) for url in urls] + [
        (save_file, (tenant_id, None, base64_file)) for base64_file in base64_files
    funcs: list[tuple[Callable[..., Any], tuple[Any, ...]]] = [
        (save_file_from_url, (url, tenant_id)) for url in urls
    ]

    # Must pass in tenant_id here, since this is called by multithreading
    return run_functions_tuples_in_parallel(funcs)

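A minimal sketch of calling the richer `save_files` variant above from the image-generation path (illustrative; the URL and base64 payload are placeholders):

```python
# Illustrative only -- placeholder URL and truncated base64 data.
file_ids = save_files(
    urls=["https://example.com/generated-1.png"],
    base64_files=["iVBORw0KGgoAAAANSUhEUg..."],
    tenant_id=tenant_id,  # the tenant the files should be stored under
)
# Each entry in file_ids is the unique ID returned by save_file for one input.
```
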
@@ -6,27 +6,27 @@ from langchain.schema.messages import BaseMessage
from langchain_core.messages import AIMessageChunk
from langchain_core.messages import ToolCall

from danswer.chat.llm_response_handler import LLMResponseHandlerManager
from danswer.chat.models import AnswerQuestionPossibleReturn
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.chat.prompt_builder.build import default_build_system_message
from danswer.chat.prompt_builder.build import default_build_user_message
from danswer.chat.prompt_builder.build import LLMCall
from danswer.chat.stream_processing.answer_response_handler import (
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import LLMResponseHandlerManager
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import default_build_system_message
from danswer.llm.answering.prompts.build import default_build_user_message
from danswer.llm.answering.stream_processing.answer_response_handler import (
    CitationResponseHandler,
)
from danswer.chat.stream_processing.answer_response_handler import (
from danswer.llm.answering.stream_processing.answer_response_handler import (
    DummyAnswerResponseHandler,
)
from danswer.chat.stream_processing.utils import map_document_id_order
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.stream_processing.utils import map_document_id_order
from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolResponse

@@ -1,22 +1,58 @@
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from typing import TYPE_CHECKING

from langchain_core.messages import BaseMessage
from pydantic.v1 import BaseModel as BaseModel__v1

from danswer.chat.models import ResponsePart
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.chat.prompt_builder.build import LLMCall
from danswer.chat.stream_processing.answer_response_handler import AnswerResponseHandler
from danswer.chat.tool_handling.tool_response_handler import ToolResponseHandler
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool


if TYPE_CHECKING:
    from danswer.llm.answering.stream_processing.answer_response_handler import (
        AnswerResponseHandler,
    )
    from danswer.llm.answering.tool.tool_response_handler import ToolResponseHandler


ResponsePart = (
    DanswerAnswerPiece
    | CitationInfo
    | ToolCallKickoff
    | ToolResponse
    | ToolCallFinalResult
    | StreamStopInfo
)


class LLMCall(BaseModel__v1):
    prompt_builder: AnswerPromptBuilder
    tools: list[Tool]
    force_use_tool: ForceUseTool
    files: list[InMemoryChatFile]
    tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
    using_tool_calling_llm: bool

    class Config:
        arbitrary_types_allowed = True


class LLMResponseHandlerManager:
    def __init__(
        self,
        tool_handler: ToolResponseHandler,
        answer_handler: AnswerResponseHandler,
        tool_handler: "ToolResponseHandler",
        answer_handler: "AnswerResponseHandler",
        is_cancelled: Callable[[], bool],
    ):
        self.tool_handler = tool_handler

backend/danswer/llm/answering/models.py (new file, 163 lines)
@@ -0,0 +1,163 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import TYPE_CHECKING

from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import model_validator

from danswer.chat.models import AnswerQuestionStreamReturn
from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.override_models import PromptOverride
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult

if TYPE_CHECKING:
    from danswer.db.models import ChatMessage
    from danswer.db.models import Prompt


StreamProcessor = Callable[[Iterator[str]], AnswerQuestionStreamReturn]


class PreviousMessage(BaseModel):
    """Simplified version of `ChatMessage`"""

    message: str
    token_count: int
    message_type: MessageType
    files: list[InMemoryChatFile]
    tool_call: ToolCallFinalResult | None

    @classmethod
    def from_chat_message(
        cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
    ) -> "PreviousMessage":
        message_file_ids = (
            [file["id"] for file in chat_message.files] if chat_message.files else []
        )
        return cls(
            message=chat_message.message,
            token_count=chat_message.token_count,
            message_type=chat_message.message_type,
            files=[
                file
                for file in available_files
                if str(file.file_id) in message_file_ids
            ],
            tool_call=ToolCallFinalResult(
                tool_name=chat_message.tool_call.tool_name,
                tool_args=chat_message.tool_call.tool_arguments,
                tool_result=chat_message.tool_call.tool_result,
            )
            if chat_message.tool_call
            else None,
        )

    def to_langchain_msg(self) -> BaseMessage:
        content = build_content_with_imgs(self.message, self.files)
        if self.message_type == MessageType.USER:
            return HumanMessage(content=content)
        elif self.message_type == MessageType.ASSISTANT:
            return AIMessage(content=content)
        else:
            return SystemMessage(content=content)


class DocumentPruningConfig(BaseModel):
    max_chunks: int | None = None
    max_window_percentage: float | None = None
    max_tokens: int | None = None
    # different pruning behavior is expected when the
    # user manually selects documents they want to chat with
    # e.g. we don't want to truncate each document to be no more
    # than one chunk long
    is_manually_selected_docs: bool = False
    # If user specifies to include additional context Chunks for each match, then different pruning
    # is used. As many Sections as possible are included, and the last Section is truncated
    # If this is false, all of the Sections are truncated if they are longer than the expected Chunk size.
    # Sections are often expected to be longer than the maximum Chunk size but Chunks should not be.
    use_sections: bool = True
    # If using tools, then we need to consider the tool length
    tool_num_tokens: int = 0
    # If using a tool message to represent the docs, then we have to JSON serialize
    # the document content, which adds to the token count.
    using_tool_message: bool = False


class ContextualPruningConfig(DocumentPruningConfig):
    num_chunk_multiple: int

    @classmethod
    def from_doc_pruning_config(
        cls, num_chunk_multiple: int, doc_pruning_config: DocumentPruningConfig
    ) -> "ContextualPruningConfig":
        return cls(num_chunk_multiple=num_chunk_multiple, **doc_pruning_config.dict())


class CitationConfig(BaseModel):
    all_docs_useful: bool = False


class QuotesConfig(BaseModel):
    pass


class AnswerStyleConfig(BaseModel):
    citation_config: CitationConfig | None = None
    quotes_config: QuotesConfig | None = None
    document_pruning_config: DocumentPruningConfig = Field(
        default_factory=DocumentPruningConfig
    )
    # forces the LLM to return a structured response, see
    # https://platform.openai.com/docs/guides/structured-outputs/introduction
    # right now, only used by the simple chat API
    structured_response_format: dict | None = None

    @model_validator(mode="after")
    def check_quotes_and_citation(self) -> "AnswerStyleConfig":
        if self.citation_config is None and self.quotes_config is None:
            raise ValueError(
                "One of `citation_config` or `quotes_config` must be provided"
            )

        if self.citation_config is not None and self.quotes_config is not None:
            raise ValueError(
                "Only one of `citation_config` or `quotes_config` must be provided"
            )

        return self


class PromptConfig(BaseModel):
    """Final representation of the Prompt configuration passed
    into the `Answer` object."""

    system_prompt: str
    task_prompt: str
    datetime_aware: bool
    include_citations: bool

    @classmethod
    def from_model(
        cls, model: "Prompt", prompt_override: PromptOverride | None = None
    ) -> "PromptConfig":
        override_system_prompt = (
            prompt_override.system_prompt if prompt_override else None
        )
        override_task_prompt = prompt_override.task_prompt if prompt_override else None

        return cls(
            system_prompt=override_system_prompt or model.system_prompt,
            task_prompt=override_task_prompt or model.task_prompt,
            datetime_aware=model.datetime_aware,
            include_citations=model.include_citations,
        )

    model_config = ConfigDict(frozen=True)

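A short sketch of how the `PreviousMessage` model above converts into a LangChain message (illustrative; a real instance is normally built from a `ChatMessage` row via `from_chat_message`):

```python
# Illustrative only.
prev = PreviousMessage(
    message="What changed in the last release?",
    token_count=9,
    message_type=MessageType.USER,
    files=[],
    tool_call=None,
)
lc_msg = prev.to_langchain_msg()
# With no attached files, this is a HumanMessage whose content is the plain string.
```
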
@@ -4,26 +4,20 @@ from typing import cast
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from pydantic.v1 import BaseModel as BaseModel__v1

from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from danswer.chat.prompt_builder.utils import translate_history_to_basemessages
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_message_tokens
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.llm.utils import translate_history_to_basemessages
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.prompts.chat_prompts import CHAT_USER_CONTEXT_FREE_PROMPT
from danswer.prompts.prompt_utils import add_date_time_to_prompt
from danswer.prompts.prompt_utils import drop_messages_history_overflow
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from danswer.tools.tool import Tool


def default_build_system_message(
@@ -145,15 +139,3 @@ class AnswerPromptBuilder:
        return drop_messages_history_overflow(
            final_messages_with_tokens, self.max_tokens
        )


class LLMCall(BaseModel__v1):
    prompt_builder: AnswerPromptBuilder
    tools: list[Tool]
    force_use_tool: ForceUseTool
    files: list[InMemoryChatFile]
    tool_call_info: list[ToolCallKickoff | ToolResponse | ToolCallFinalResult]
    using_tool_calling_llm: bool

    class Config:
        arbitrary_types_allowed = True

@@ -2,12 +2,12 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.model_configs import GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS
from danswer.context.search.models import InferenceChunk
from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
@@ -1,10 +1,10 @@
from langchain.schema.messages import HumanMessage

from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.configs.chat_configs import LANGUAGE_HINT
from danswer.context.search.models import InferenceChunk
from danswer.db.search_settings import get_multilingual_expansion
from danswer.llm.answering.models import PromptConfig
from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.prompts.direct_qa_prompts import CONTEXT_BLOCK
from danswer.prompts.direct_qa_prompts import HISTORY_BLOCK

backend/danswer/llm/answering/prompts/utils.py (new file, 20 lines)
@@ -0,0 +1,20 @@
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT
from danswer.prompts.direct_qa_prompts import PARAMATERIZED_PROMPT_WITHOUT_CONTEXT


def build_dummy_prompt(
    system_prompt: str, task_prompt: str, retrieval_disabled: bool
) -> str:
    if retrieval_disabled:
        return PARAMATERIZED_PROMPT_WITHOUT_CONTEXT.format(
            user_query="<USER_QUERY>",
            system_prompt=system_prompt,
            task_prompt=task_prompt,
        ).strip()

    return PARAMATERIZED_PROMPT.format(
        context_docs_str="<CONTEXT_DOCS>",
        user_query="<USER_QUERY>",
        system_prompt=system_prompt,
        task_prompt=task_prompt,
    ).strip()

@@ -5,16 +5,16 @@ from typing import TypeVar

from pydantic import BaseModel

from danswer.chat.models import ContextualPruningConfig
from danswer.chat.models import (
    LlmDoc,
)
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.citations_prompt import compute_max_document_tokens
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.citations_prompt import compute_max_document_tokens
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer
from danswer.natural_language_processing.utils import tokenizer_trim_content

@@ -3,11 +3,13 @@ from collections.abc import Generator

from langchain_core.messages import BaseMessage

from danswer.chat.llm_response_handler import ResponsePart
from danswer.chat.models import CitationInfo
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.citation_processing import CitationProcessor
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.answering.stream_processing.citation_processing import (
    CitationProcessor,
)
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.utils.logger import setup_logger

logger = setup_logger()
@@ -4,8 +4,8 @@ from collections.abc import Generator
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.configs.chat_configs import STOP_STREAM_PAT
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
from danswer.prompts.constants import TRIPLE_BACKTICK
from danswer.utils.logger import setup_logger

@@ -4,8 +4,8 @@ from langchain_core.messages import AIMessageChunk
from langchain_core.messages import BaseMessage
from langchain_core.messages import ToolCall

from danswer.chat.models import ResponsePart
from danswer.chat.prompt_builder.build import LLMCall
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.llm_response_handler import ResponsePart
from danswer.llm.interfaces import LLM
from danswer.tools.force import ForceUseTool
from danswer.tools.message import build_tool_message

@@ -1,59 +0,0 @@
from typing import TYPE_CHECKING

from langchain.schema.messages import AIMessage
from langchain.schema.messages import BaseMessage
from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from pydantic import BaseModel

from danswer.configs.constants import MessageType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.utils import build_content_with_imgs
from danswer.tools.models import ToolCallFinalResult

if TYPE_CHECKING:
    from danswer.db.models import ChatMessage


class PreviousMessage(BaseModel):
    """Simplified version of `ChatMessage`"""

    message: str
    token_count: int
    message_type: MessageType
    files: list[InMemoryChatFile]
    tool_call: ToolCallFinalResult | None

    @classmethod
    def from_chat_message(
        cls, chat_message: "ChatMessage", available_files: list[InMemoryChatFile]
    ) -> "PreviousMessage":
        message_file_ids = (
            [file["id"] for file in chat_message.files] if chat_message.files else []
        )
        return cls(
            message=chat_message.message,
            token_count=chat_message.token_count,
            message_type=chat_message.message_type,
            files=[
                file
                for file in available_files
                if str(file.file_id) in message_file_ids
            ],
            tool_call=ToolCallFinalResult(
                tool_name=chat_message.tool_call.tool_name,
                tool_args=chat_message.tool_call.tool_arguments,
                tool_result=chat_message.tool_call.tool_result,
            )
            if chat_message.tool_call
            else None,
        )

    def to_langchain_msg(self) -> BaseMessage:
        content = build_content_with_imgs(self.message, self.files)
        if self.message_type == MessageType.USER:
            return HumanMessage(content=content)
        elif self.message_type == MessageType.ASSISTANT:
            return AIMessage(content=content)
        else:
            return SystemMessage(content=content)

@@ -5,6 +5,8 @@ from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
from typing import Union

import litellm  # type: ignore
import pandas as pd
@@ -34,15 +36,17 @@ from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from danswer.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS
from danswer.db.models import ChatMessage
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.interfaces import LLM
from danswer.prompts.constants import CODE_BLOCK_PAT
from danswer.utils.b64 import get_image_type
from danswer.utils.b64 import get_image_type_from_bytes
from danswer.utils.logger import setup_logger
from shared_configs.configs import LOG_LEVEL

if TYPE_CHECKING:
    from danswer.llm.answering.models import PreviousMessage

logger = setup_logger()


@@ -100,6 +104,39 @@ def litellm_exception_to_error_msg(
    return error_msg


def translate_danswer_msg_to_langchain(
    msg: Union[ChatMessage, "PreviousMessage"],
) -> BaseMessage:
    files: list[InMemoryChatFile] = []

    # If the message is a `ChatMessage`, it doesn't have the downloaded files
    # attached. Just ignore them for now.
    if not isinstance(msg, ChatMessage):
        files = msg.files
    content = build_content_with_imgs(msg.message, files, message_type=msg.message_type)

    if msg.message_type == MessageType.SYSTEM:
        raise ValueError("System messages are not currently part of history")
    if msg.message_type == MessageType.ASSISTANT:
        return AIMessage(content=content)
    if msg.message_type == MessageType.USER:
        return HumanMessage(content=content)

    raise ValueError(f"New message type {msg.message_type} not handled")


def translate_history_to_basemessages(
    history: list[ChatMessage] | list["PreviousMessage"],
) -> tuple[list[BaseMessage], list[int]]:
    history_basemessages = [
        translate_danswer_msg_to_langchain(msg)
        for msg in history
        if msg.token_count != 0
    ]
    history_token_counts = [msg.token_count for msg in history if msg.token_count != 0]
    return history_basemessages, history_token_counts


# Processes CSV files to show the first 5 rows and max_columns (default 40) columns
def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str:
    df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))
@@ -153,7 +190,6 @@ def build_content_with_imgs(
    message: str,
    files: list[InMemoryChatFile] | None = None,
    img_urls: list[str] | None = None,
    b64_imgs: list[str] | None = None,
    message_type: MessageType = MessageType.USER,
) -> str | list[str | dict[str, Any]]:  # matching Langchain's BaseMessage content type
    files = files or []
@@ -166,7 +202,6 @@ def build_content_with_imgs(
    )

    img_urls = img_urls or []
    b64_imgs = b64_imgs or []

    message_main_content = _build_content(message, files)

@@ -185,22 +220,11 @@ def build_content_with_imgs(
            {
                "type": "image_url",
                "image_url": {
                    "url": (
                        f"data:{get_image_type_from_bytes(file.content)};"
                        f"base64,{file.to_base64()}"
                    ),
                    "url": f"data:image/jpeg;base64,{file.to_base64()}",
                },
            }
            for file in img_files
        ]
        + [
            {
                "type": "image_url",
                "image_url": {
                    "url": f"data:{get_image_type(b64_img)};base64,{b64_img}",
                },
            }
            for b64_img in b64_imgs
            for file in files
            if file.file_type == "image"
        ]
        + [
            {

@@ -105,6 +105,7 @@ from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import SENTRY_DSN


logger = setup_logger()

@@ -5,11 +5,11 @@ from typing import cast
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.configs.chat_configs import LANGUAGE_HINT
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.db.models import Prompt
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.prompts.chat_prompts import ADDITIONAL_INFO
|
||||
from danswer.prompts.chat_prompts import CITATION_REMINDER
|
||||
from danswer.prompts.constants import CODE_BLOCK_PAT
|
||||
|
||||
@@ -3,14 +3,14 @@ from langchain.schema import HumanMessage
|
||||
from langchain.schema import SystemMessage
|
||||
|
||||
from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.chat.prompt_builder.utils import translate_danswer_msg_to_langchain
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.llm.utils import translate_danswer_msg_to_langchain
|
||||
from danswer.prompts.chat_prompts import AGGRESSIVE_SEARCH_TEMPLATE
|
||||
from danswer.prompts.chat_prompts import NO_SEARCH
|
||||
from danswer.prompts.chat_prompts import REQUIRE_SEARCH_HINT
|
||||
|
||||
@@ -4,10 +4,10 @@ from danswer.chat.chat_utils import combine_message_chain
|
||||
from danswer.configs.chat_configs import DISABLE_LLM_QUERY_REPHRASE
|
||||
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
|
||||
from danswer.db.models import ChatMessage
|
||||
from danswer.llm.answering.models import PreviousMessage
|
||||
from danswer.llm.exceptions import GenAIDisabledException
|
||||
from danswer.llm.factory import get_default_llms
|
||||
from danswer.llm.interfaces import LLM
|
||||
from danswer.llm.models import PreviousMessage
|
||||
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
|
||||
from danswer.llm.utils import message_to_string
|
||||
from danswer.prompts.chat_prompts import HISTORY_QUERY_REPHRASE
|
||||
|
||||
@@ -86,7 +86,6 @@ from danswer.db.models import SearchSettings
|
||||
from danswer.db.models import User
|
||||
from danswer.db.search_settings import get_current_search_settings
|
||||
from danswer.db.search_settings import get_secondary_search_settings
|
||||
from danswer.file_processing.extract_file_text import convert_docx_to_txt
|
||||
from danswer.file_store.file_store import get_default_file_store
|
||||
from danswer.key_value_store.interface import KvKeyNotFoundError
|
||||
from danswer.redis.redis_connector import RedisConnector
|
||||
@@ -394,12 +393,6 @@ def upload_files(
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
file_type=file.content_type or "text/plain",
|
||||
)
|
||||
|
||||
if file.content_type and file.content_type.startswith(
|
||||
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
convert_docx_to_txt(file, file_store, file_path)
|
||||
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return FileUploadResponse(file_paths=deduped_file_paths)
|
||||
@@ -1017,18 +1010,37 @@ def get_connector_by_id(
|
||||
|
||||
|
||||
class BasicCCPairInfo(BaseModel):
|
||||
docs_indexed: int
|
||||
has_successful_run: bool
|
||||
source: DocumentSource
|
||||
|
||||
|
||||
@router.get("/connector-status")
|
||||
@router.get("/indexing-status")
|
||||
def get_basic_connector_indexing_status(
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BasicCCPairInfo]:
|
||||
cc_pairs = get_connector_credential_pairs(db_session)
|
||||
cc_pair_identifiers = [
|
||||
ConnectorCredentialPairIdentifier(
|
||||
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
]
|
||||
document_count_info = get_document_counts_for_cc_pairs(
|
||||
db_session=db_session,
|
||||
cc_pair_identifiers=cc_pair_identifiers,
|
||||
)
|
||||
cc_pair_to_document_cnt = {
|
||||
(connector_id, credential_id): cnt
|
||||
for connector_id, credential_id, cnt in document_count_info
|
||||
}
|
||||
return [
|
||||
BasicCCPairInfo(
|
||||
docs_indexed=cc_pair_to_document_cnt.get(
|
||||
(cc_pair.connector_id, cc_pair.credential_id)
|
||||
)
|
||||
or 0,
|
||||
has_successful_run=cc_pair.last_successful_index_time is not None,
|
||||
source=cc_pair.connector.source,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from danswer.auth.users import current_admin_user
from danswer.auth.users import current_curator_or_admin_user
from danswer.auth.users import current_limited_user
from danswer.auth.users import current_user
from danswer.chat.prompt_builder.utils import build_dummy_prompt
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import NotificationType
from danswer.db.engine import get_session
@@ -34,6 +33,7 @@ from danswer.db.persona import update_persona_shared_users
from danswer.db.persona import update_persona_visibility
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.llm.answering.prompts.utils import build_dummy_prompt
from danswer.server.features.persona.models import CreatePersonaRequest
from danswer.server.features.persona.models import ImageGenerationToolStatus
from danswer.server.features.persona.models import PersonaCategoryCreate

@@ -194,11 +194,11 @@ def bulk_invite_users(
)

tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
new_invited_emails = []
normalized_emails = []
try:
for email in emails:
email_info = validate_email(email)
new_invited_emails.append(email_info.normalized)
normalized_emails.append(email_info.normalized) # type: ignore

except (EmailUndeliverableError, EmailNotValidError) as e:
raise HTTPException(

@@ -210,7 +210,7 @@ def bulk_invite_users(
try:
fetch_ee_implementation_or_noop(
"danswer.server.tenants.provisioning", "add_users_to_tenant", None
)(new_invited_emails, tenant_id)
)(normalized_emails, tenant_id)

except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):

@@ -224,7 +224,7 @@ def bulk_invite_users(

initial_invited_users = get_invited_users()

all_emails = list(set(new_invited_emails) | set(initial_invited_users))
all_emails = list(set(normalized_emails) | set(initial_invited_users))
number_of_invited_users = write_invited_users(all_emails)

if not MULTI_TENANT:

@@ -236,7 +236,7 @@ def bulk_invite_users(
)(CURRENT_TENANT_ID_CONTEXTVAR.get(), get_total_users_count(db_session))
if ENABLE_EMAIL_INVITES:
try:
for email in new_invited_emails:
for email in all_emails:
send_user_email_invite(email, current_user)
except Exception as e:
logger.error(f"Error sending email invite to invited users: {e}")

@@ -250,7 +250,7 @@ def bulk_invite_users(
write_invited_users(initial_invited_users) # Reset to original state
fetch_ee_implementation_or_noop(
"danswer.server.tenants.user_mapping", "remove_users_from_tenant", None
)(new_invited_emails, tenant_id)
)(normalized_emails, tenant_id)
raise e

@@ -1,7 +1,6 @@
import asyncio
import io
import json
import os
import uuid
from collections.abc import Callable
from collections.abc import Generator

@@ -24,9 +23,6 @@ from danswer.auth.users import current_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.chat.chat_utils import extract_headers
from danswer.chat.process_message import stream_chat_message
from danswer.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import FileOrigin
from danswer.configs.constants import MessageType

@@ -51,11 +47,13 @@ from danswer.db.models import User
from danswer.db.persona import get_persona_by_id
from danswer.document_index.document_index_utils import get_both_index_names
from danswer.document_index.factory import get_default_document_index
from danswer.file_processing.extract_file_text import docx_to_txt_filename
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import FileDescriptor
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llms_for_persona

@@ -720,18 +718,6 @@ def fetch_chat_file(
if not file_record:
raise HTTPException(status_code=404, detail="File not found")

original_file_name = file_record.display_name
if file_record.file_type.startswith(
"application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
# Check if a converted text file exists for .docx files
txt_file_name = docx_to_txt_filename(original_file_name)
txt_file_id = os.path.join(os.path.dirname(file_id), txt_file_name)
txt_file_record = file_store.read_file_record(txt_file_id)
if txt_file_record:
file_record = txt_file_record
file_id = txt_file_id

media_type = file_record.file_type
file_io = file_store.read_file(file_id, mode="b")


@@ -1,6 +1,5 @@
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
from uuid import UUID

from pydantic import BaseModel

@@ -23,9 +22,6 @@ from danswer.llm.override_models import LLMOverride
from danswer.llm.override_models import PromptOverride
from danswer.tools.models import ToolCallFinalResult

if TYPE_CHECKING:
pass


class SourceTag(Tag):
source: DocumentSource

@@ -4,7 +4,6 @@ from sqlalchemy.orm import Session

from danswer.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP
from danswer.configs.app_configs import MANAGED_VESPA
from danswer.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP
from danswer.configs.constants import KV_REINDEX_KEY
from danswer.configs.constants import KV_SEARCH_SETTINGS
from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION

@@ -222,13 +221,13 @@ def setup_vespa(
document_index: DocumentIndex,
index_setting: IndexingSetting,
secondary_index_setting: IndexingSetting | None,
num_attempts: int = VESPA_NUM_ATTEMPTS_ON_STARTUP,
) -> bool:
# Vespa startup is a bit slow, so give it a few seconds
WAIT_SECONDS = 5
for x in range(num_attempts):
VESPA_ATTEMPTS = 5
for x in range(VESPA_ATTEMPTS):
try:
logger.notice(f"Setting up Vespa (attempt {x+1}/{num_attempts})...")
logger.notice(f"Setting up Vespa (attempt {x+1}/{VESPA_ATTEMPTS})...")
document_index.ensure_indices_exist(
index_embedding_dim=index_setting.model_dim,
secondary_index_embedding_dim=secondary_index_setting.model_dim

@@ -245,7 +244,7 @@ def setup_vespa(
time.sleep(WAIT_SECONDS)

logger.error(
f"Vespa setup did not succeed. Attempt limit reached. ({num_attempts})"
f"Vespa setup did not succeed. Attempt limit reached. ({VESPA_ATTEMPTS})"
)
return False

@@ -7,7 +7,7 @@ from danswer.llm.utils import message_to_prompt_and_imgs
from danswer.tools.tool import Tool

if TYPE_CHECKING:
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.tool_implementations.custom.custom_tool import (
CustomToolCallSummary,
)

@@ -3,13 +3,13 @@ from collections.abc import Generator
from typing import Any
from typing import TYPE_CHECKING

from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.utils.special_types import JSON_ro


if TYPE_CHECKING:
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse


@@ -5,10 +5,6 @@ from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session

from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import CitationConfig
from danswer.chat.models import DocumentPruningConfig
from danswer.chat.models import PromptConfig
from danswer.configs.app_configs import AZURE_DALLE_API_BASE
from danswer.configs.app_configs import AZURE_DALLE_API_KEY
from danswer.configs.app_configs import AZURE_DALLE_API_VERSION

@@ -23,6 +19,10 @@ from danswer.db.llm import fetch_existing_llm_providers
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.interfaces import LLM
from danswer.llm.interfaces import LLMConfig
from danswer.natural_language_processing.utils import get_tokenizer

@@ -15,14 +15,14 @@ from langchain_core.messages import SystemMessage
from pydantic import BaseModel
from requests import JSONDecodeError

from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.configs.constants import FileOrigin
from danswer.db.engine import get_session_with_default_tenant
from danswer.file_store.file_store import get_default_file_store
from danswer.file_store.models import ChatFileType
from danswer.file_store.models import InMemoryChatFile
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.tools.base_tool import BaseTool
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import CHAT_SESSION_ID_PLACEHOLDER

@@ -4,16 +4,14 @@ from enum import Enum
from typing import Any
from typing import cast

import requests
from litellm import image_generation # type: ignore
from pydantic import BaseModel

from danswer.chat.chat_utils import combine_message_chain
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.configs.tool_configs import IMAGE_GENERATION_OUTPUT_FORMAT
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT

@@ -58,18 +56,9 @@ Follow Up Input:
""".strip()


class ImageFormat(str, Enum):
URL = "url"
BASE64 = "b64_json"


_DEFAULT_OUTPUT_FORMAT = ImageFormat(IMAGE_GENERATION_OUTPUT_FORMAT)


class ImageGenerationResponse(BaseModel):
revised_prompt: str
url: str | None
image_data: str | None
url: str


class ImageShape(str, Enum):

@@ -91,7 +80,6 @@ class ImageGenerationTool(Tool):
model: str = "dall-e-3",
num_imgs: int = 2,
additional_headers: dict[str, str] | None = None,
output_format: ImageFormat = _DEFAULT_OUTPUT_FORMAT,
) -> None:
self.api_key = api_key
self.api_base = api_base

@@ -101,7 +89,6 @@ class ImageGenerationTool(Tool):
self.num_imgs = num_imgs

self.additional_headers = additional_headers
self.output_format = output_format

@property
def name(self) -> str:

@@ -181,7 +168,7 @@ class ImageGenerationTool(Tool):
)

return build_content_with_imgs(
message=json.dumps(
json.dumps(
[
{
"revised_prompt": image_generation.revised_prompt,

@@ -190,10 +177,13 @@ class ImageGenerationTool(Tool):
for image_generation in image_generations
]
),
# NOTE: we can't pass in the image URLs here, since OpenAI doesn't allow
# Tool messages to contain images
# img_urls=[image_generation.url for image_generation in image_generations],
)

def _generate_image(
self, prompt: str, shape: ImageShape, format: ImageFormat
self, prompt: str, shape: ImageShape
) -> ImageGenerationResponse:
if shape == ImageShape.LANDSCAPE:
size = "1792x1024"

@@ -207,32 +197,20 @@ class ImageGenerationTool(Tool):
prompt=prompt,
model=self.model,
api_key=self.api_key,
# need to pass in None rather than empty str
api_base=self.api_base or None,
api_version=self.api_version or None,
size=size,
n=1,
response_format=format,
extra_headers=build_llm_extra_headers(self.additional_headers),
)

if format == ImageFormat.URL:
url = response.data[0]["url"]
image_data = None
else:
url = None
image_data = response.data[0]["b64_json"]

return ImageGenerationResponse(
revised_prompt=response.data[0]["revised_prompt"],
url=url,
image_data=image_data,
url=response.data[0]["url"],
)

except requests.RequestException as e:
logger.error(f"Error fetching or converting image: {e}")
raise ValueError("Failed to fetch or convert the generated image")
except Exception as e:
logger.debug(f"Error occurred during image generation: {e}")
logger.debug(f"Error occured during image generation: {e}")

error_message = str(e)
if "OpenAIException" in str(type(e)):

@@ -257,8 +235,9 @@ class ImageGenerationTool(Tool):
def run(self, **kwargs: str) -> Generator[ToolResponse, None, None]:
prompt = cast(str, kwargs["prompt"])
shape = ImageShape(kwargs.get("shape", ImageShape.SQUARE))
format = self.output_format

# dalle3 only supports 1 image at a time, which is why we have to
# parallelize this via threading
results = cast(
list[ImageGenerationResponse],
run_functions_tuples_in_parallel(

@@ -268,7 +247,6 @@ class ImageGenerationTool(Tool):
(
prompt,
shape,
format,
),
)
for _ in range(self.num_imgs)

@@ -310,17 +288,11 @@ class ImageGenerationTool(Tool):
if img_generation_response is None:
raise ValueError("No image generation response found")

img_urls = [img.url for img in img_generation_response if img.url is not None]
b64_imgs = [
img.image_data
for img in img_generation_response
if img.image_data is not None
]
img_urls = [img.url for img in img_generation_response]
prompt_builder.update_user_prompt(
build_image_generation_user_prompt(
query=prompt_builder.get_user_message_content(),
img_urls=img_urls,
b64_imgs=b64_imgs,
)
)


@@ -11,14 +11,11 @@ Can you please summarize them in a sentence or two? Do NOT include image urls or


def build_image_generation_user_prompt(
query: str,
img_urls: list[str] | None = None,
b64_imgs: list[str] | None = None,
query: str, img_urls: list[str] | None = None
) -> HumanMessage:
return HumanMessage(
content=build_content_with_imgs(
message=IMG_GENERATION_SUMMARY_PROMPT.format(query=query).strip(),
b64_imgs=b64_imgs,
img_urls=img_urls,
)
)

@@ -7,15 +7,15 @@ from typing import cast
import httpx

from danswer.chat.chat_utils import combine_message_chain
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.configs.constants import DocumentSource
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.context.search.models import SearchDoc
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import message_to_string
from danswer.prompts.chat_prompts import INTERNET_SEARCH_QUERY_REPHRASE
from danswer.prompts.constants import GENERAL_SEP_PAT

@@ -7,19 +7,10 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.chat.chat_utils import llm_doc_from_inference_section
from danswer.chat.llm_response_handler import LLMCall
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import ContextualPruningConfig
from danswer.chat.models import DanswerContext
from danswer.chat.models import DanswerContexts
from danswer.chat.models import DocumentPruningConfig
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.chat.models import SectionRelevancePiece
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.chat.prompt_builder.citations_prompt import compute_max_llm_input_tokens
from danswer.chat.prune_and_merge import prune_and_merge_sections
from danswer.chat.prune_and_merge import prune_sections
from danswer.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from danswer.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from danswer.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS

@@ -34,8 +25,17 @@ from danswer.context.search.models import SearchRequest
from danswer.context.search.pipeline import SearchPipeline
from danswer.db.models import Persona
from danswer.db.models import User
from danswer.llm.answering.llm_response_handler import LLMCall
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import ContextualPruningConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.citations_prompt import compute_max_llm_input_tokens
from danswer.llm.answering.prune_and_merge import prune_and_merge_sections
from danswer.llm.answering.prune_and_merge import prune_sections
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.secondary_llm_flows.choose_search import check_if_need_search
from danswer.secondary_llm_flows.query_expansion import history_based_query_rephrase
from danswer.tools.message import ToolCallSummary

@@ -2,15 +2,15 @@ from typing import cast

from langchain_core.messages import HumanMessage

from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
from danswer.chat.prompt_builder.citations_prompt import (
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
from danswer.llm.answering.prompts.citations_prompt import (
build_citations_system_message,
)
from danswer.chat.prompt_builder.citations_prompt import build_citations_user_message
from danswer.chat.prompt_builder.quotes_prompt import build_quotes_user_message
from danswer.llm.answering.prompts.citations_prompt import build_citations_user_message
from danswer.llm.answering.prompts.quotes_prompt import build_quotes_user_message
from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse


@@ -2,8 +2,8 @@ from collections.abc import Callable
from collections.abc import Generator
from typing import Any

from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse

@@ -3,8 +3,8 @@ from typing import Any

from danswer.chat.chat_utils import combine_message_chain
from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.interfaces import LLM
from danswer.llm.models import PreviousMessage
from danswer.llm.utils import message_to_string
from danswer.prompts.constants import GENERAL_SEP_PAT
from danswer.tools.tool import Tool

@@ -1,25 +0,0 @@
import base64


def get_image_type_from_bytes(raw_b64_bytes: bytes) -> str:
magic_number = raw_b64_bytes[:4]

if magic_number.startswith(b"\x89PNG"):
mime_type = "image/png"
elif magic_number.startswith(b"\xFF\xD8"):
mime_type = "image/jpeg"
elif magic_number.startswith(b"GIF8"):
mime_type = "image/gif"
elif magic_number.startswith(b"RIFF") and raw_b64_bytes[8:12] == b"WEBP":
mime_type = "image/webp"
else:
raise ValueError(
"Unsupported image format - only PNG, JPEG, " "GIF, and WEBP are supported."
)

return mime_type


def get_image_type(raw_b64_string: str) -> str:
binary_data = base64.b64decode(raw_b64_string)
return get_image_type_from_bytes(binary_data)
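
For context only (not part of the diff): a minimal sketch of how the two helpers above feed the LangChain-style image parts assembled in the build_content_with_imgs hunk earlier. The function name and input are illustrative assumptions, not code from this PR.

import base64

def to_image_part(content: bytes) -> dict:
    # Sniff the MIME type from the magic bytes, then inline the image as a data URL,
    # mirroring the f"data:{get_image_type_from_bytes(file.content)};base64,..." pattern above.
    mime_type = get_image_type_from_bytes(content)
    b64 = base64.b64encode(content).decode("utf-8")
    return {"type": "image_url", "image_url": {"url": f"data:{mime_type};base64,{b64}"}}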

@@ -28,6 +28,3 @@ JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)
# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", '["pablo@danswer.ai"]'))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")

OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")

@@ -170,67 +170,3 @@ def fetch_danswerbot_analytics(
)

return results


def fetch_persona_message_analytics(
db_session: Session,
persona_id: int,
start: datetime.datetime,
end: datetime.datetime,
) -> list[tuple[int, datetime.date]]:
"""Gets the daily message counts for a specific persona within the given time range."""
query = (
select(
func.count(ChatMessage.id),
cast(ChatMessage.time_sent, Date),
)
.join(
ChatSession,
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
)
.group_by(cast(ChatMessage.time_sent, Date))
.order_by(cast(ChatMessage.time_sent, Date))
)

return [tuple(row) for row in db_session.execute(query).all()]


def fetch_persona_unique_users(
db_session: Session,
persona_id: int,
start: datetime.datetime,
end: datetime.datetime,
) -> list[tuple[int, datetime.date]]:
"""Gets the daily unique user counts for a specific persona within the given time range."""
query = (
select(
func.count(func.distinct(ChatSession.user_id)),
cast(ChatMessage.time_sent, Date),
)
.join(
ChatSession,
ChatMessage.chat_session_id == ChatSession.id,
)
.where(
or_(
ChatMessage.alternate_assistant_id == persona_id,
ChatSession.persona_id == persona_id,
),
ChatMessage.time_sent >= start,
ChatMessage.time_sent <= end,
ChatMessage.message_type == MessageType.ASSISTANT,
)
.group_by(cast(ChatMessage.time_sent, Date))
.order_by(cast(ChatMessage.time_sent, Date))
)

return [tuple(row) for row in db_session.execute(query).all()]

@@ -242,9 +242,7 @@ def _fetch_all_page_restrictions_for_space(
)
continue

logger.warning(
f"No permissions found for document {slim_doc.id} in space {space_key}"
)
logger.warning(f"No permissions found for document {slim_doc.id}")

logger.debug("Finished fetching all page restrictions for space")
return document_restrictions

@@ -26,7 +26,6 @@ from ee.danswer.server.enterprise_settings.api import (
)
from ee.danswer.server.manage.standard_answer import router as standard_answer_router
from ee.danswer.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.danswer.server.oauth import router as oauth_router
from ee.danswer.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -120,8 +119,6 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, oauth_router)

# Enterprise-only global settings
include_router_with_global_prefix_prepended(
application, enterprise_settings_admin_router

@@ -11,16 +11,11 @@ from danswer.db.engine import get_session
from danswer.db.models import User
from ee.danswer.db.analytics import fetch_danswerbot_analytics
from ee.danswer.db.analytics import fetch_per_user_query_analytics
from ee.danswer.db.analytics import fetch_persona_message_analytics
from ee.danswer.db.analytics import fetch_persona_unique_users
from ee.danswer.db.analytics import fetch_query_analytics

router = APIRouter(prefix="/analytics")


_DEFAULT_LOOKBACK_DAYS = 30


class QueryAnalyticsResponse(BaseModel):
total_queries: int
total_likes: int

@@ -38,7 +33,7 @@ def get_query_analytics(
daily_query_usage_info = fetch_query_analytics(
start=start
or (
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
datetime.datetime.utcnow() - datetime.timedelta(days=30)
), # default is 30d lookback
end=end or datetime.datetime.utcnow(),
db_session=db_session,

@@ -69,7 +64,7 @@ def get_user_analytics(
daily_query_usage_info_per_user = fetch_per_user_query_analytics(
start=start
or (
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
datetime.datetime.utcnow() - datetime.timedelta(days=30)
), # default is 30d lookback
end=end or datetime.datetime.utcnow(),
db_session=db_session,

@@ -103,7 +98,7 @@ def get_danswerbot_analytics(
daily_danswerbot_info = fetch_danswerbot_analytics(
start=start
or (
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
datetime.datetime.utcnow() - datetime.timedelta(days=30)
), # default is 30d lookback
end=end or datetime.datetime.utcnow(),
db_session=db_session,

@@ -120,74 +115,3 @@ def get_danswerbot_analytics(
]

return resolution_results


class PersonaMessageAnalyticsResponse(BaseModel):
total_messages: int
date: datetime.date
persona_id: int


@router.get("/admin/persona/messages")
def get_persona_messages(
persona_id: int,
start: datetime.datetime | None = None,
end: datetime.datetime | None = None,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[PersonaMessageAnalyticsResponse]:
"""Fetch daily message counts for a single persona within the given time range."""
start = start or (
datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS)
)
end = end or datetime.datetime.utcnow()

persona_message_counts = []
for count, date in fetch_persona_message_analytics(
db_session=db_session,
persona_id=persona_id,
start=start,
end=end,
):
persona_message_counts.append(
PersonaMessageAnalyticsResponse(
total_messages=count,
date=date,
persona_id=persona_id,
)
)

return persona_message_counts


class PersonaUniqueUsersResponse(BaseModel):
unique_users: int
date: datetime.date
persona_id: int


@router.get("/admin/persona/unique-users")
def get_persona_unique_users(
persona_id: int,
start: datetime.datetime,
end: datetime.datetime,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[PersonaUniqueUsersResponse]:
"""Get unique users per day for a single persona."""
unique_user_counts = []
daily_counts = fetch_persona_unique_users(
db_session=db_session,
persona_id=persona_id,
start=start,
end=end,
)
for count, date in daily_counts:
unique_user_counts.append(
PersonaUniqueUsersResponse(
unique_users=count,
date=date,
persona_id=persona_id,
)
)
return unique_user_counts
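
Purely as an illustration (not part of the diff), the persona analytics routes above could be smoke-tested roughly as follows. The base URL, any global API prefix, and the auth cookie are assumptions that will differ per deployment.

import requests

BASE_URL = "http://localhost:8080"  # assumed host and port
resp = requests.get(
    f"{BASE_URL}/analytics/admin/persona/unique-users",
    params={
        "persona_id": 1,
        "start": "2024-11-01T00:00:00",
        "end": "2024-12-01T00:00:00",
    },
    cookies={"session": "<admin auth cookie>"},  # placeholder; real auth depends on the deployment
)
resp.raise_for_status()
print(resp.json())  # e.g. [{"unique_users": 3, "date": "2024-11-02", "persona_id": 1}, ...]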

@@ -1,423 +0,0 @@
import base64
import uuid
from typing import cast

import requests
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.db.credentials import create_credential
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CredentialBase
from danswer.utils.logger import setup_logger
from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_ID
from ee.danswer.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET


logger = setup_logger()

router = APIRouter(prefix="/oauth")


class SlackOAuth:
# https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth
# Example: https://api.slack.com/authentication/oauth-v2#exchanging

class OAuthSession(BaseModel):
"""Stored in redis to be looked up on callback"""

email: str
redirect_on_success: str | None # Where to send the user if OAuth flow succeeds

CLIENT_ID = OAUTH_SLACK_CLIENT_ID
CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET

TOKEN_URL = "https://slack.com/api/oauth.v2.access"

# SCOPE is per https://docs.danswer.dev/connectors/slack
BOT_SCOPE = (
"channels:history,"
"channels:read,"
"groups:history,"
"groups:read,"
"channels:join,"
"im:history,"
"users:read,"
"users:read.email,"
"usergroups:read"
)

REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback"
DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

@classmethod
def generate_oauth_url(cls, state: str) -> str:
url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.REDIRECT_URI}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url

@classmethod
def generate_dev_oauth_url(cls, state: str) -> str:
"""dev mode workaround for localhost testing
- https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
"""

url = (
f"https://slack.com/oauth/v2/authorize"
f"?client_id={cls.CLIENT_ID}"
f"&redirect_uri={cls.DEV_REDIRECT_URI}"
f"&scope={cls.BOT_SCOPE}"
f"&state={state}"
)
return url

@classmethod
def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
"""Temporary state to store in redis. to be looked up on auth response.
Returns a json string.
"""
session = SlackOAuth.OAuthSession(
email=email, redirect_on_success=redirect_on_success
)
return session.model_dump_json()

@classmethod
def parse_session(cls, session_json: str) -> OAuthSession:
session = SlackOAuth.OAuthSession.model_validate_json(session_json)
return session


# Work in progress
# class ConfluenceCloudOAuth:
# """work in progress"""

# # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/

# class OAuthSession(BaseModel):
# """Stored in redis to be looked up on callback"""

# email: str
# redirect_on_success: str | None # Where to send the user if OAuth flow succeeds

# CLIENT_ID = OAUTH_CONFLUENCE_CLIENT_ID
# CLIENT_SECRET = OAUTH_CONFLUENCE_CLIENT_SECRET
# TOKEN_URL = "https://auth.atlassian.com/oauth/token"

# # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/
# CONFLUENCE_OAUTH_SCOPE = (
# "read:confluence-props%20"
# "read:confluence-content.all%20"
# "read:confluence-content.summary%20"
# "read:confluence-content.permission%20"
# "read:confluence-user%20"
# "read:confluence-groups%20"
# "readonly:content.attachment:confluence"
# )

# REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback"
# DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}"

# # eventually for Confluence Data Center
# # oauth_url = (
# # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}"
# # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}"
# # f"&redirect_uri={redirectme_uri}"
# # )

# @classmethod
# def generate_oauth_url(cls, state: str) -> str:
# return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state)

# @classmethod
# def generate_dev_oauth_url(cls, state: str) -> str:
# """dev mode workaround for localhost testing
# - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https
# """
# return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state)

# @classmethod
# def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str:
# url = (
# "https://auth.atlassian.com/authorize"
# f"?audience=api.atlassian.com"
# f"&client_id={cls.CLIENT_ID}"
# f"&redirect_uri={redirect_uri}"
# f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}"
# f"&state={state}"
# "&response_type=code"
# "&prompt=consent"
# )
# return url

# @classmethod
# def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str:
# """Temporary state to store in redis. to be looked up on auth response.
# Returns a json string.
# """
# session = ConfluenceCloudOAuth.OAuthSession(
# email=email, redirect_on_success=redirect_on_success
# )
# return session.model_dump_json()

# @classmethod
# def parse_session(cls, session_json: str) -> SlackOAuth.OAuthSession:
# session = SlackOAuth.OAuthSession.model_validate_json(session_json)
# return session


@router.post("/prepare-authorization-request")
def prepare_authorization_request(
connector: DocumentSource,
redirect_on_success: str | None,
user: User = Depends(current_user),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
"""Used by the frontend to generate the url for the user's browser during auth request.

Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/
"""

oauth_uuid = uuid.uuid4()
oauth_uuid_str = str(oauth_uuid)
oauth_state = (
base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")
)

if connector == DocumentSource.SLACK:
oauth_url = SlackOAuth.generate_oauth_url(oauth_state)
session = SlackOAuth.session_dump_json(
email=user.email, redirect_on_success=redirect_on_success
)
# elif connector == DocumentSource.CONFLUENCE:
# oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state)
# session = ConfluenceCloudOAuth.session_dump_json(
# email=user.email, redirect_on_success=redirect_on_success
# )
# elif connector == DocumentSource.JIRA:
# oauth_url = JiraCloudOAuth.generate_dev_oauth_url(oauth_state)
# elif connector == DocumentSource.GOOGLE_DRIVE:
# oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state)
else:
oauth_url = None

if not oauth_url:
raise HTTPException(
status_code=404,
detail=f"The document source type {connector} does not have OAuth implemented",
)

r = get_redis_client(tenant_id=tenant_id)

# 10 min is the max we want an oauth flow to be valid
r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600)

return JSONResponse(content={"url": oauth_url})


@router.post("/connector/slack/callback")
def handle_slack_oauth_callback(
code: str,
state: str,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> JSONResponse:
if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET:
raise HTTPException(
status_code=500,
detail="Slack client ID or client secret is not configured.",
)

r = get_redis_client(tenant_id=tenant_id)

# recover the state
padded_state = state + "=" * (
-len(state) % 4
) # Add padding back (Base64 decoding requires padding)
uuid_bytes = base64.urlsafe_b64decode(
padded_state
) # Decode the Base64 string back to bytes

# Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes)
oauth_uuid_str = str(oauth_uuid)

r_key = f"da_oauth:{oauth_uuid_str}"

session_json_bytes = cast(bytes, r.get(r_key))
if not session_json_bytes:
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}",
)

session_json = session_json_bytes.decode("utf-8")
try:
session = SlackOAuth.parse_session(session_json)

# Exchange the authorization code for an access token
response = requests.post(
SlackOAuth.TOKEN_URL,
headers={"Content-Type": "application/x-www-form-urlencoded"},
data={
"client_id": SlackOAuth.CLIENT_ID,
"client_secret": SlackOAuth.CLIENT_SECRET,
"code": code,
"redirect_uri": SlackOAuth.REDIRECT_URI,
},
)

response_data = response.json()

if not response_data.get("ok"):
raise HTTPException(
status_code=400,
detail=f"Slack OAuth failed: {response_data.get('error')}",
)

# Extract token and team information
access_token: str = response_data.get("access_token")
team_id: str = response_data.get("team", {}).get("id")
authed_user_id: str = response_data.get("authed_user", {}).get("id")

credential_info = CredentialBase(
credential_json={"slack_bot_token": access_token},
admin_public=True,
source=DocumentSource.SLACK,
name="Slack OAuth",
)

create_credential(credential_info, user, db_session)
except Exception as e:
return JSONResponse(
status_code=500,
content={
"success": False,
"message": f"An error occurred during Slack OAuth: {str(e)}",
},
)
finally:
r.delete(r_key)

# return the result
return JSONResponse(
content={
"success": True,
"message": "Slack OAuth completed successfully.",
"team_id": team_id,
"authed_user_id": authed_user_id,
"redirect_on_success": session.redirect_on_success,
}
)
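
A minimal sketch (not part of the diff) of the state round-trip the two handlers above rely on: the OAuth state is a UUID that is base64url-encoded with its padding stripped for the authorize URL, then re-padded and decoded in the callback.

import base64
import uuid

oauth_uuid = uuid.uuid4()
# encode: urlsafe base64 of the raw UUID bytes, padding stripped (as in prepare_authorization_request)
state = base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8")

# decode: restore the stripped "=" padding, then rebuild the UUID (as in handle_slack_oauth_callback)
padded_state = state + "=" * (-len(state) % 4)
assert uuid.UUID(bytes=base64.urlsafe_b64decode(padded_state)) == oauth_uuid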

# Work in progress
# @router.post("/connector/confluence/callback")
# def handle_confluence_oauth_callback(
# code: str,
# state: str,
# user: User = Depends(current_user),
# db_session: Session = Depends(get_session),
# tenant_id: str | None = Depends(get_current_tenant_id),
# ) -> JSONResponse:
# if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET:
# raise HTTPException(
# status_code=500,
# detail="Confluence client ID or client secret is not configured."
# )

# r = get_redis_client(tenant_id=tenant_id)

# # recover the state
# padded_state = state + '=' * (-len(state) % 4) # Add padding back (Base64 decoding requires padding)
# uuid_bytes = base64.urlsafe_b64decode(padded_state) # Decode the Base64 string back to bytes

# # Convert bytes back to a UUID
# oauth_uuid = uuid.UUID(bytes=uuid_bytes)
# oauth_uuid_str = str(oauth_uuid)

# r_key = f"da_oauth:{oauth_uuid_str}"

# result = r.get(r_key)
# if not result:
# raise HTTPException(
# status_code=400,
# detail=f"Confluence OAuth failed - OAuth state key not found: key={r_key}"
# )

# try:
# session = ConfluenceCloudOAuth.parse_session(result)

# # Exchange the authorization code for an access token
# response = requests.post(
# ConfluenceCloudOAuth.TOKEN_URL,
# headers={"Content-Type": "application/x-www-form-urlencoded"},
# data={
# "client_id": ConfluenceCloudOAuth.CLIENT_ID,
# "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET,
# "code": code,
# "redirect_uri": ConfluenceCloudOAuth.DEV_REDIRECT_URI,
# },
# )

# response_data = response.json()

# if not response_data.get("ok"):
# raise HTTPException(
# status_code=400,
# detail=f"ConfluenceCloudOAuth OAuth failed: {response_data.get('error')}"
# )

# # Extract token and team information
# access_token: str = response_data.get("access_token")
# team_id: str = response_data.get("team", {}).get("id")
# authed_user_id: str = response_data.get("authed_user", {}).get("id")

# credential_info = CredentialBase(
# credential_json={"slack_bot_token": access_token},
# admin_public=True,
# source=DocumentSource.CONFLUENCE,
# name="Confluence OAuth",
# )

# logger.info(f"Slack access token: {access_token}")

# credential = create_credential(credential_info, user, db_session)

# logger.info(f"new_credential_id={credential.id}")
# except Exception as e:
# return JSONResponse(
# status_code=500,
# content={
# "success": False,
# "message": f"An error occurred during Slack OAuth: {str(e)}",
# },
# )
# finally:
# r.delete(r_key)

# # return the result
# return JSONResponse(
# content={
# "success": True,
# "message": "Slack OAuth completed successfully.",
# "team_id": team_id,
# "authed_user_id": authed_user_id,
# "redirect_on_success": session.redirect_on_success,
# }
# )
@@ -61,7 +61,7 @@ LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "danswer"
|
||||
# Enable generating persistent log files for local dev environments
|
||||
DEV_LOGGING_ENABLED = os.environ.get("DEV_LOGGING_ENABLED", "").lower() == "true"
|
||||
# notset, debug, info, notice, warning, error, or critical
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL", "info")
|
||||
LOG_LEVEL = os.environ.get("LOG_LEVEL", "notice")
|
||||
|
||||
# Timeout for API-based embedding models
|
||||
# NOTE: does not apply for Google VertexAI, since the python client doesn't
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from danswer.connectors.confluence.connector import ConfluenceConnector
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def confluence_connector() -> ConfluenceConnector:
|
||||
connector = ConfluenceConnector(
|
||||
wiki_base="https://danswerai.atlassian.net",
|
||||
is_cloud=True,
|
||||
)
|
||||
connector.load_credentials(
|
||||
{
|
||||
"confluence_access_token": os.environ["CONFLUENCE_ACCESS_TOKEN"],
|
||||
"confluence_username": os.environ["CONFLUENCE_USER_NAME"],
|
||||
}
|
||||
)
|
||||
return connector
|
||||
|
||||
|
||||
# This should never fail because even if the docs in the cloud change,
|
||||
# the full doc ids retrieved should always be a subset of the slim doc ids
|
||||
def test_confluence_connector_permissions(
|
||||
confluence_connector: ConfluenceConnector,
|
||||
) -> None:
|
||||
# Get all doc IDs from the full connector
|
||||
all_full_doc_ids = set()
|
||||
for doc_batch in confluence_connector.load_from_state():
|
||||
all_full_doc_ids.update([doc.id for doc in doc_batch])
|
||||
|
||||
# Get all doc IDs from the slim connector
|
||||
all_slim_doc_ids = set()
|
||||
for slim_doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||
all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])
|
||||
|
||||
# The set of full doc IDs should be always be a subset of the slim doc IDs
|
||||
assert all_full_doc_ids.issubset(all_slim_doc_ids)
|
||||
@@ -2,10 +2,14 @@ import textwrap
|
||||
|
||||
import pytest
|
||||
|
||||
from danswer.chat.stream_processing.quotes_processing import match_quotes_to_docs
|
||||
from danswer.chat.stream_processing.quotes_processing import separate_answer_quotes
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.context.search.models import InferenceChunk
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
match_quotes_to_docs,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.quotes_processing import (
|
||||
separate_answer_quotes,
|
||||
)
|
||||
|
||||
|
||||
def test_passed_in_quotes() -> None:
|
||||
@@ -5,12 +5,12 @@ from unittest.mock import MagicMock
|
||||
import pytest
|
||||
from langchain_core.messages import SystemMessage
|
||||
|
||||
from danswer.chat.models import AnswerStyleConfig
|
||||
from danswer.chat.models import CitationConfig
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.models import PromptConfig
|
||||
from danswer.chat.prompt_builder.build import AnswerPromptBuilder
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.llm.answering.models import AnswerStyleConfig
|
||||
from danswer.llm.answering.models import CitationConfig
|
||||
from danswer.llm.answering.models import PromptConfig
|
||||
from danswer.llm.answering.prompts.build import AnswerPromptBuilder
|
||||
from danswer.llm.interfaces import LLMConfig
|
||||
from danswer.tools.models import ToolResponse
|
||||
from danswer.tools.tool_implementations.search.search_tool import SearchTool
|
||||
@@ -5,9 +5,11 @@ import pytest
|
||||
from danswer.chat.models import CitationInfo
|
||||
from danswer.chat.models import DanswerAnswerPiece
|
||||
from danswer.chat.models import LlmDoc
|
||||
from danswer.chat.stream_processing.citation_processing import CitationProcessor
|
||||
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.llm.answering.stream_processing.citation_processing import (
|
||||
CitationProcessor,
|
||||
)
|
||||
from danswer.llm.answering.stream_processing.utils import DocumentIdOrderMapping
|
||||
|
||||
|
||||
"""
|
||||
@@ -11,21 +11,21 @@ from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolCall
from langchain_core.messages import ToolCallChunk

from danswer.chat.answer import Answer
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.models import PromptConfig
from danswer.chat.models import StreamStopInfo
from danswer.chat.models import StreamStopReason
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.interfaces import LLM
from danswer.tools.force import ForceUseTool
from danswer.tools.models import ToolCallFinalResult
from danswer.tools.models import ToolCallKickoff
from danswer.tools.models import ToolResponse
from tests.unit.danswer.chat.conftest import DEFAULT_SEARCH_ARGS
from tests.unit.danswer.chat.conftest import QUERY
from tests.unit.danswer.llm.answering.conftest import DEFAULT_SEARCH_ARGS
from tests.unit.danswer.llm.answering.conftest import QUERY


@pytest.fixture
@@ -1,9 +1,9 @@
import pytest

from danswer.chat.prune_and_merge import _merge_sections
from danswer.configs.constants import DocumentSource
from danswer.context.search.models import InferenceChunk
from danswer.context.search.models import InferenceSection
from danswer.llm.answering.prune_and_merge import _merge_sections


# This large test accounts for all of the following:
@@ -5,10 +5,10 @@ from unittest.mock import Mock
import pytest
from pytest_mock import MockerFixture

from danswer.chat.answer import Answer
from danswer.chat.answer import AnswerStream
from danswer.chat.models import AnswerStyleConfig
from danswer.chat.models import PromptConfig
from danswer.llm.answering.answer import Answer
from danswer.llm.answering.answer import AnswerStream
from danswer.llm.answering.models import AnswerStyleConfig
from danswer.llm.answering.models import PromptConfig
from danswer.tools.force import ForceUseTool
from danswer.tools.tool_implementations.search.search_tool import SearchTool
from tests.regression.answer_quality.run_qa import _process_and_write_query_results
@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-beat
image: danswer/danswer-backend-cloud:v0.14.0-cloud.beta.4
image: danswer/danswer-backend-cloud:v0.15.0-cloud.beta.0
imagePullPolicy: IfNotPresent
command:
[

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-heavy
image: danswer/danswer-backend-cloud:v0.14.0-cloud.beta.4
image: danswer/danswer-backend-cloud:v0.15.0-cloud.beta.0
imagePullPolicy: IfNotPresent
command:
[

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-indexing
image: danswer/danswer-backend-cloud:v0.14.0-cloud.beta.4
image: danswer/danswer-backend-cloud:v0.15.0-cloud.beta.0
imagePullPolicy: IfNotPresent
command:
[

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-light
image: danswer/danswer-backend-cloud:v0.14.0-cloud.beta.4
image: danswer/danswer-backend-cloud:v0.15.0-cloud.beta.0
imagePullPolicy: IfNotPresent
command:
[

@@ -14,7 +14,7 @@ spec:
spec:
containers:
- name: celery-worker-primary
image: danswer/danswer-backend-cloud:v0.14.0-cloud.beta.4
image: danswer/danswer-backend-cloud:v0.15.0-cloud.beta.0
imagePullPolicy: IfNotPresent
command:
[
27 web/package-lock.json generated
@@ -73,6 +73,7 @@
},
"devDependencies": {
"@chromatic-com/playwright": "^0.10.0",
"@playwright/test": "^1.49.0",
"@tailwindcss/typography": "^0.5.10",
"chromatic": "^11.18.1",
"eslint": "^8.48.0",
@@ -2573,14 +2574,13 @@
}
},
"node_modules/@playwright/test": {
"version": "1.48.2",
"resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.48.2.tgz",
"integrity": "sha512-54w1xCWfXuax7dz4W2M9uw0gDyh+ti/0K/MxcCUxChFh37kkdxPdfZDw5QBbuPUJHr1CiHJ1hXgSs+GgeQc5Zw==",
"version": "1.49.0",
"resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.49.0.tgz",
"integrity": "sha512-DMulbwQURa8rNIQrf94+jPJQ4FmOVdpE5ZppRNvWVjvhC+6sOeo28r8MgIpQRYouXRtt/FCCXU7zn20jnHR4Qw==",
"devOptional": true,
"license": "Apache-2.0",
"peer": true,
"dependencies": {
"playwright": "1.48.2"
"playwright": "1.49.0"
},
"bin": {
"playwright": "cli.js"
@@ -13329,14 +13329,13 @@
}
},
"node_modules/playwright": {
"version": "1.48.2",
"resolved": "https://registry.npmjs.org/playwright/-/playwright-1.48.2.tgz",
"integrity": "sha512-NjYvYgp4BPmiwfe31j4gHLa3J7bD2WiBz8Lk2RoSsmX38SVIARZ18VYjxLjAcDsAhA+F4iSEXTSGgjua0rrlgQ==",
"version": "1.49.0",
"resolved": "https://registry.npmjs.org/playwright/-/playwright-1.49.0.tgz",
"integrity": "sha512-eKpmys0UFDnfNb3vfsf8Vx2LEOtflgRebl0Im2eQQnYMA4Aqd+Zw8bEOB+7ZKvN76901mRnqdsiOGKxzVTbi7A==",
"devOptional": true,
"license": "Apache-2.0",
"peer": true,
"dependencies": {
"playwright-core": "1.48.2"
"playwright-core": "1.49.0"
},
"bin": {
"playwright": "cli.js"
@@ -13349,12 +13348,11 @@
}
},
"node_modules/playwright-core": {
"version": "1.48.2",
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.48.2.tgz",
"integrity": "sha512-sjjw+qrLFlriJo64du+EK0kJgZzoQPsabGF4lBvsid+3CNIZIYLgnMj9V6JY5VhM2Peh20DJWIVpVljLLnlawA==",
"version": "1.49.0",
"resolved": "https://registry.npmjs.org/playwright-core/-/playwright-core-1.49.0.tgz",
"integrity": "sha512-R+3KKTQF3npy5GTiKH/T+kdhoJfJojjHESR1YEWhYuEKRVfVaxH3+4+GvXE5xyCngCxhxnykk0Vlah9v8fs3jA==",
"devOptional": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"playwright-core": "cli.js"
},
@@ -13373,7 +13371,6 @@
"os": [
"darwin"
],
"peer": true,
"engines": {
"node": "^8.16.0 || ^10.6.0 || >=11.0.0"
}
@@ -75,6 +75,7 @@
},
"devDependencies": {
"@chromatic-com/playwright": "^0.10.0",
"@playwright/test": "^1.49.0",
"@tailwindcss/typography": "^0.5.10",
"chromatic": "^11.18.1",
"eslint": "^8.48.0",
@@ -7,7 +7,6 @@ import { useState } from "react";
import { SlackTokensForm } from "./SlackTokensForm";
import { SourceIcon } from "@/components/SourceIcon";
import { AdminPageTitle } from "@/components/admin/Title";
import { ValidSources } from "@/lib/types";

export const NewSlackBotForm = ({}: {}) => {
const [formValues] = useState({
@@ -22,7 +21,7 @@ export const NewSlackBotForm = ({}: {}) => {
return (
<div>
<AdminPageTitle
icon={<SourceIcon iconSize={36} sourceType={ValidSources.Slack} />}
icon={<SourceIcon iconSize={36} sourceType={"slack"} />}
title="New Slack Bot"
/>
<CardSection>
@@ -1,7 +1,7 @@
"use client";

import { usePopup } from "@/components/admin/connectors/Popup";
import { SlackBot, ValidSources } from "@/lib/types";
import { SlackBot } from "@/lib/types";
import { useRouter } from "next/navigation";
import { ChevronDown, ChevronRight } from "lucide-react";
import { useState, useEffect, useRef } from "react";
@@ -78,7 +78,7 @@ export const ExistingSlackBotForm = ({
<div className="flex items-center justify-between h-14">
<div className="flex items-center gap-2">
<div className="my-auto">
<SourceIcon iconSize={32} sourceType={ValidSources.Slack} />
<SourceIcon iconSize={32} sourceType={"slack"} />
</div>
<div className="ml-1">
<EditableStringFieldDisplay
@@ -276,7 +276,13 @@ export const SlackChannelConfigCreationForm = ({

{showAdvancedOptions && (
<div className="mt-4">
<div className="w-64 mb-4">
<BooleanFormField
name="show_continue_in_web_ui"
removeIndent
label="Show Continue in Web UI button"
tooltip="If set, will show a button at the bottom of the response that allows the user to continue the conversation in the Danswer Web UI"
/>
<div className="w-64 mb-4 mt-4">
<SelectorFormField
name="response_type"
label="Answer Type"
@@ -288,12 +294,6 @@ export const SlackChannelConfigCreationForm = ({
/>
</div>

<BooleanFormField
name="show_continue_in_web_ui"
removeIndent
label="Show Continue in Web UI button"
tooltip="If set, will show a button at the bottom of the response that allows the user to continue the conversation in the Danswer Web UI"
/>
<div className="flex flex-col space-y-3 mt-2">
<BooleanFormField
name="still_need_help_enabled"
@@ -325,8 +325,8 @@ export const SlackChannelConfigCreationForm = ({
<BooleanFormField
name="answer_validity_check_enabled"
removeIndent
label="Only respond if citations found"
tooltip="If set, will only answer questions where the model successfully produces citations"
label="Hide Non-Answers"
tooltip="If set, will only answer questions that the model determines it can answer"
/>
<BooleanFormField
name="questionmark_prefilter_enabled"
@@ -3,7 +3,7 @@ import { SourceIcon } from "@/components/SourceIcon";
import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm";
import { fetchSS } from "@/lib/utilsSS";
import { ErrorCallout } from "@/components/ErrorCallout";
import { DocumentSet, SlackChannelConfig, ValidSources } from "@/lib/types";
import { DocumentSet, SlackChannelConfig } from "@/lib/types";
import { BackButton } from "@/components/BackButton";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import {
@@ -84,7 +84,7 @@ async function EditslackChannelConfigPage(props: {

<BackButton />
<AdminPageTitle
icon={<SourceIcon sourceType={ValidSources.Slack} iconSize={32} />}
icon={<SourceIcon sourceType={"slack"} iconSize={32} />}
title="Edit Slack Channel Config"
/>
@@ -2,7 +2,7 @@ import { AdminPageTitle } from "@/components/admin/Title";
import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm";
import { fetchSS } from "@/lib/utilsSS";
import { ErrorCallout } from "@/components/ErrorCallout";
import { DocumentSet, ValidSources } from "@/lib/types";
import { DocumentSet } from "@/lib/types";
import { BackButton } from "@/components/BackButton";
import { fetchAssistantsSS } from "@/lib/assistants/fetchAssistantsSS";
import {
@@ -59,7 +59,7 @@ async function NewChannelConfigPage(props: {
<div className="container mx-auto">
<BackButton />
<AdminPageTitle
icon={<SourceIcon iconSize={32} sourceType={ValidSources.Slack} />}
icon={<SourceIcon iconSize={32} sourceType={"slack"} />}
title="Configure DanswerBot for Slack Channel"
/>
@@ -10,7 +10,6 @@ import Link from "next/link";
import { SourceIcon } from "@/components/SourceIcon";
import { SlackBotTable } from "./SlackBotTable";
import { useSlackBots } from "./[bot-id]/hooks";
import { ValidSources } from "@/lib/types";

const Main = () => {
const {
@@ -104,7 +103,7 @@ const Page = () => {
return (
<div className="container mx-auto">
<AdminPageTitle
icon={<SourceIcon iconSize={36} sourceType={ValidSources.Slack} />}
icon={<SourceIcon iconSize={36} sourceType={"slack"} />}
title="Slack Bots"
/>
<InstantSSRAutoRefresh />
@@ -25,13 +25,13 @@ function buildConfigEntries(
obj: any,
sourceType: ValidSources
): { [key: string]: string } {
if (sourceType === ValidSources.File) {
if (sourceType === "file") {
return obj.file_locations
? {
file_names: obj.file_locations.map(getNameFromPath),
}
: {};
} else if (sourceType === ValidSources.GoogleSites) {
} else if (sourceType === "google_sites") {
return {
base_url: obj.base_url,
};
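Many of the hunks in this comparison toggle between `ValidSources` string-enum members (e.g. `ValidSources.File`) and their literal values (`"file"`). A minimal, self-contained sketch of why the two forms compare equal at runtime while the enum member remains the refactor-friendly spelling; the enum values below mirror the literals shown in the diff but are otherwise assumptions, not the repository's actual definition:

```typescript
// Minimal sketch, not part of the repository: a string enum member and its
// literal value are the same string at runtime, but only the enum member can
// be renamed and exhaustively checked by the compiler.
enum ValidSources {
  File = "file",
  GoogleSites = "google_sites",
  Slack = "slack",
  Web = "web",
}

function describe(sourceType: ValidSources): string {
  if (sourceType === ValidSources.File) {
    return "uploaded files";
  } else if (sourceType === ValidSources.GoogleSites) {
    return "a Google Sites export";
  }
  return "another connector";
}

console.log(describe(ValidSources.File)); // "uploaded files"
console.log(ValidSources.File === ("file" as ValidSources)); // true
```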
@@ -30,7 +30,7 @@ import { Button } from "@/components/ui/button";
// since the uploaded files are cleaned up after some period of time
// re-indexing will not work for the file connector. Also, it would not
// make sense to re-index, since the files will not have changed.
const CONNECTOR_TYPES_THAT_CANT_REINDEX: ValidSources[] = [ValidSources.File];
const CONNECTOR_TYPES_THAT_CANT_REINDEX: ValidSources[] = ["file"];

function Main({ ccPairId }: { ccPairId: number }) {
const router = useRouter(); // Initialize the router
@@ -9,9 +9,9 @@ import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { buildSimilarCredentialInfoURL } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { useFormContext } from "@/components/context/FormContext";
|
||||
import { getSourceDisplayName, getSourceMetadata } from "@/lib/sources";
|
||||
import { getSourceDisplayName } from "@/lib/sources";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { useEffect, useState } from "react";
|
||||
import { useState } from "react";
|
||||
import { deleteCredential, linkCredential } from "@/lib/credential";
|
||||
import { submitFiles } from "./pages/utils/files";
|
||||
import { submitGoogleSite } from "./pages/utils/google_site";
|
||||
@@ -43,8 +43,6 @@ import { Formik } from "formik";
|
||||
import NavigationRow from "./NavigationRow";
|
||||
import { useRouter } from "next/navigation";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils";
|
||||
import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
export interface AdvancedConfig {
|
||||
refreshFreq: number;
|
||||
pruneFreq: number;
|
||||
@@ -112,23 +110,6 @@ export default function AddConnector({
|
||||
}: {
|
||||
connector: ConfigurableSources;
|
||||
}) {
|
||||
const [currentPageUrl, setCurrentPageUrl] = useState<string | null>(null);
|
||||
const [oauthUrl, setOauthUrl] = useState<string | null>(null);
|
||||
const [isAuthorizing, setIsAuthorizing] = useState(false);
|
||||
const [isAuthorizeVisible, setIsAuthorizeVisible] = useState(false);
|
||||
useEffect(() => {
|
||||
if (typeof window !== "undefined") {
|
||||
setCurrentPageUrl(window.location.href);
|
||||
}
|
||||
|
||||
if (EE_ENABLED && NEXT_PUBLIC_CLOUD_ENABLED) {
|
||||
const sourceMetadata = getSourceMetadata(connector);
|
||||
if (sourceMetadata?.oauthSupported == true) {
|
||||
setIsAuthorizeVisible(true);
|
||||
}
|
||||
}
|
||||
}, []);
|
||||
|
||||
const router = useRouter();
|
||||
|
||||
// State for managing credentials and files
|
||||
@@ -154,13 +135,8 @@ export default function AddConnector({
|
||||
const configuration: ConnectionConfiguration = connectorConfigs[connector];
|
||||
|
||||
// Form context and popup management
|
||||
const {
|
||||
setFormStep,
|
||||
setAllowCreate: setAllowCreate,
|
||||
formStep,
|
||||
nextFormStep,
|
||||
prevFormStep,
|
||||
} = useFormContext();
|
||||
const { setFormStep, setAlowCreate, formStep, nextFormStep, prevFormStep } =
|
||||
useFormContext();
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
// Hooks for Google Drive and Gmail credentials
|
||||
@@ -216,7 +192,7 @@ export default function AddConnector({
|
||||
|
||||
const onSwap = async (selectedCredential: Credential<any>) => {
|
||||
setCurrentCredential(selectedCredential);
|
||||
setAllowCreate(true);
|
||||
setAlowCreate(true);
|
||||
setPopup({
|
||||
message: "Swapped credential successfully!",
|
||||
type: "success",
|
||||
@@ -228,37 +204,6 @@ export default function AddConnector({
|
||||
router.push("/admin/indexing/status?message=connector-created");
|
||||
};
|
||||
|
||||
const handleAuthorize = async () => {
|
||||
// authorize button handler
|
||||
// gets an auth url from the server and directs the user to it in a popup
|
||||
|
||||
if (!currentPageUrl) return;
|
||||
|
||||
setIsAuthorizing(true);
|
||||
try {
|
||||
const response = await prepareOAuthAuthorizationRequest(
|
||||
connector,
|
||||
currentPageUrl
|
||||
);
|
||||
if (response.url) {
|
||||
setOauthUrl(response.url);
|
||||
window.open(response.url, "_blank", "noopener,noreferrer");
|
||||
} else {
|
||||
setPopup({ message: "Failed to fetch OAuth URL", type: "error" });
|
||||
}
|
||||
} catch (error: unknown) {
|
||||
// Narrow the type of error
|
||||
if (error instanceof Error) {
|
||||
setPopup({ message: `Error: ${error.message}`, type: "error" });
|
||||
} else {
|
||||
// Handle non-standard errors
|
||||
setPopup({ message: "An unknown error occurred", type: "error" });
|
||||
}
|
||||
} finally {
|
||||
setIsAuthorizing(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<Formik
|
||||
initialValues={{
|
||||
@@ -440,31 +385,16 @@ export default function AddConnector({
|
||||
onSwitch={onSwap}
|
||||
/>
|
||||
{!createConnectorToggle && (
|
||||
<div className="mt-6 flex space-x-4">
|
||||
{/* Button to pop up a form to manually enter credentials */}
|
||||
<button
|
||||
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
|
||||
onClick={() =>
|
||||
setCreateConnectorToggle(
|
||||
(createConnectorToggle) => !createConnectorToggle
|
||||
)
|
||||
}
|
||||
>
|
||||
Create New
|
||||
</button>
|
||||
|
||||
{/* Button to sign in via OAuth */}
|
||||
<button
|
||||
onClick={handleAuthorize}
|
||||
className="mt-6 text-sm bg-blue-500 px-2 py-1.5 flex text-text-200 flex-none rounded"
|
||||
disabled={isAuthorizing}
|
||||
hidden={!isAuthorizeVisible}
|
||||
>
|
||||
{isAuthorizing
|
||||
? "Authorizing..."
|
||||
: `Authorize with ${getSourceDisplayName(connector)}`}
|
||||
</button>
|
||||
</div>
|
||||
<button
|
||||
className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded"
|
||||
onClick={() =>
|
||||
setCreateConnectorToggle(
|
||||
(createConnectorToggle) => !createConnectorToggle
|
||||
)
|
||||
}
|
||||
>
|
||||
Create New
|
||||
</button>
|
||||
)}
|
||||
|
||||
{/* NOTE: connector will never be google_drive, since the ternary above will
|
||||
|
||||
@@ -1,111 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useEffect, useState } from "react";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import Title from "@/components/ui/title";
|
||||
import { KeyIcon } from "@/components/icons/icons";
|
||||
import { getSourceMetadata, isValidSource } from "@/lib/sources";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { handleOAuthAuthorizationResponse } from "@/lib/oauth_utils";
|
||||
|
||||
export default function OAuthCallbackPage() {
|
||||
const router = useRouter();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
const [statusMessage, setStatusMessage] = useState("Processing...");
|
||||
const [statusDetails, setStatusDetails] = useState(
|
||||
"Please wait while we complete the setup."
|
||||
);
|
||||
const [redirectUrl, setRedirectUrl] = useState<string | null>(null);
|
||||
const [isError, setIsError] = useState(false);
|
||||
const [pageTitle, setPageTitle] = useState(
|
||||
"Authorize with Third-Party service"
|
||||
);
|
||||
|
||||
// Extract query parameters
|
||||
const code = searchParams.get("code");
|
||||
const state = searchParams.get("state");
|
||||
|
||||
const pathname = usePathname();
|
||||
const connector = pathname?.split("/")[3];
|
||||
|
||||
useEffect(() => {
|
||||
const handleOAuthCallback = async () => {
|
||||
if (!code || !state) {
|
||||
setStatusMessage("Improperly formed OAuth authorization request.");
|
||||
setStatusDetails(
|
||||
!code ? "Missing authorization code." : "Missing state parameter."
|
||||
);
|
||||
setIsError(true);
|
||||
return;
|
||||
}
|
||||
|
||||
if (!connector || !isValidSource(connector)) {
|
||||
setStatusMessage(
|
||||
`The specified connector source type ${connector} does not exist.`
|
||||
);
|
||||
setStatusDetails(`${connector} is not a valid source type.`);
|
||||
setIsError(true);
|
||||
return;
|
||||
}
|
||||
|
||||
const sourceMetadata = getSourceMetadata(connector as ValidSources);
|
||||
setPageTitle(`Authorize with ${sourceMetadata.displayName}`);
|
||||
|
||||
setStatusMessage("Processing...");
|
||||
setStatusDetails("Please wait while we complete authorization.");
|
||||
setIsError(false); // Ensure no error state during loading
|
||||
|
||||
try {
|
||||
const response = await handleOAuthAuthorizationResponse(code, state);
|
||||
|
||||
if (!response) {
|
||||
throw new Error("Empty response from OAuth server.");
|
||||
}
|
||||
|
||||
setStatusMessage("Success!");
|
||||
setStatusDetails(
|
||||
`Your authorization with ${sourceMetadata.displayName} completed successfully.`
|
||||
);
|
||||
setRedirectUrl(response.redirect_on_success); // Extract the redirect URL
|
||||
setIsError(false);
|
||||
} catch (error) {
|
||||
console.error("OAuth error:", error);
|
||||
setStatusMessage("Oops, something went wrong!");
|
||||
setStatusDetails(
|
||||
"An error occurred during the OAuth process. Please try again."
|
||||
);
|
||||
setIsError(true);
|
||||
}
|
||||
};
|
||||
|
||||
handleOAuthCallback();
|
||||
}, [code, state, connector]);
|
||||
|
||||
return (
|
||||
<div className="container mx-auto py-8">
|
||||
<AdminPageTitle title={pageTitle} icon={<KeyIcon size={32} />} />
|
||||
|
||||
<div className="flex flex-col items-center justify-center min-h-screen">
|
||||
<CardSection className="max-w-md">
|
||||
<h1 className="text-2xl font-bold mb-4">{statusMessage}</h1>
|
||||
<p className="text-text-500">{statusDetails}</p>
|
||||
{redirectUrl && !isError && (
|
||||
<div className="mt-4">
|
||||
<p className="text-sm">
|
||||
Click{" "}
|
||||
<a href={redirectUrl} className="text-blue-500 underline">
|
||||
here
|
||||
</a>{" "}
|
||||
to continue.
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</CardSection>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -2,7 +2,7 @@ import { PopupSpec } from "@/components/admin/connectors/Popup";
import { createConnector, runConnector } from "@/lib/connector";
import { createCredential, linkCredential } from "@/lib/credential";
import { FileConfig } from "@/lib/connectors/connectors";
import { AccessType, ValidSources } from "@/lib/types";
import { AccessType } from "@/lib/types";

export const submitFiles = async (
selectedFiles: File[],
@@ -34,7 +34,7 @@ export const submitFiles = async (

const [connectorErrorMsg, connector] = await createConnector<FileConfig>({
name: "FileConnector-" + Date.now(),
source: ValidSources.File,
source: "file",
input_type: "load_state",
connector_specific_config: {
file_locations: filePaths,
@@ -60,7 +60,7 @@ export const submitFiles = async (
const createCredentialResponse = await createCredential({
credential_json: {},
admin_public: true,
source: ValidSources.File,
source: "file",
curator_public: true,
groups: groups,
name,
@@ -2,7 +2,6 @@ import { PopupSpec } from "@/components/admin/connectors/Popup";
import { createConnector, runConnector } from "@/lib/connector";
import { linkCredential } from "@/lib/credential";
import { GoogleSitesConfig } from "@/lib/connectors/connectors";
import { ValidSources } from "@/lib/types";

export const submitGoogleSite = async (
selectedFiles: File[],
@@ -39,7 +38,7 @@ export const submitGoogleSite = async (
const [connectorErrorMsg, connector] =
await createConnector<GoogleSitesConfig>({
name: name ? name : `GoogleSitesConnector-${base_url}`,
source: ValidSources.GoogleSites,
source: "google_sites",
input_type: "load_state",
connector_specific_config: {
base_url: base_url,
@@ -384,7 +384,7 @@ export function CCPairIndexingStatusTable({
last_status: "success",
connector: {
name: "Sample File Connector",
source: ValidSources.File,
source: "file",
input_type: "poll",
connector_specific_config: {
file_locations: ["/path/to/sample/file.txt"],
@@ -401,7 +401,7 @@ export function CCPairIndexingStatusTable({
credential: {
id: 1,
name: "Sample Credential",
source: ValidSources.File,
source: "file",
user_id: "1",
time_created: "2023-07-01T12:00:00Z",
time_updated: "2023-07-01T12:00:00Z",
@@ -982,7 +982,7 @@ export function ChatPage({
|
||||
) {
|
||||
setDocumentSidebarToggled(false);
|
||||
}
|
||||
}, [chatSessionIdRef.current]);
|
||||
}, [selectedDocuments, filtersToggled]);
|
||||
|
||||
useEffect(() => {
|
||||
adjustDocumentSidebarWidth(); // Adjust the width on initial render
|
||||
@@ -1610,14 +1610,14 @@ export function ChatPage({
|
||||
}
|
||||
});
|
||||
};
|
||||
const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open
|
||||
const [showDocSidebar, setShowDocSidebar] = useState(false); // State to track if sidebar is open
|
||||
|
||||
// Used to maintain a "time out" for history sidebar so our existing refs can have time to process change
|
||||
const [untoggled, setUntoggled] = useState(false);
|
||||
const [loadingError, setLoadingError] = useState<string | null>(null);
|
||||
|
||||
const explicitlyUntoggle = () => {
|
||||
setShowHistorySidebar(false);
|
||||
setShowDocSidebar(false);
|
||||
|
||||
setUntoggled(true);
|
||||
setTimeout(() => {
|
||||
@@ -1636,7 +1636,7 @@ export function ChatPage({
|
||||
toggle();
|
||||
};
|
||||
const removeToggle = () => {
|
||||
setShowHistorySidebar(false);
|
||||
setShowDocSidebar(false);
|
||||
toggle(false);
|
||||
};
|
||||
|
||||
@@ -1646,8 +1646,8 @@ export function ChatPage({
|
||||
useSidebarVisibility({
|
||||
toggledSidebar,
|
||||
sidebarElementRef,
|
||||
showDocSidebar: showHistorySidebar,
|
||||
setShowDocSidebar: setShowHistorySidebar,
|
||||
showDocSidebar,
|
||||
setShowDocSidebar,
|
||||
setToggled: removeToggle,
|
||||
mobile: settings?.isMobile,
|
||||
});
|
||||
@@ -1923,7 +1923,6 @@ export function ChatPage({
|
||||
interface RegenerationRequest {
|
||||
messageId: number;
|
||||
parentMessage: Message;
|
||||
forceSearch?: boolean;
|
||||
}
|
||||
|
||||
function createRegenerator(regenerationRequest: RegenerationRequest) {
|
||||
@@ -1933,7 +1932,6 @@ export function ChatPage({
|
||||
modelOverRide,
|
||||
messageIdToResend: regenerationRequest.parentMessage.messageId,
|
||||
regenerationRequest,
|
||||
forceSearch: regenerationRequest.forceSearch,
|
||||
});
|
||||
};
|
||||
}
|
||||
@@ -2100,7 +2098,7 @@ export function ChatPage({
|
||||
duration-300
|
||||
ease-in-out
|
||||
${
|
||||
!untoggled && (showHistorySidebar || toggledSidebar)
|
||||
!untoggled && (showDocSidebar || toggledSidebar)
|
||||
? "opacity-100 w-[250px] translate-x-0"
|
||||
: "opacity-0 w-[200px] pointer-events-none -translate-x-10"
|
||||
}`}
|
||||
@@ -2114,7 +2112,7 @@ export function ChatPage({
|
||||
ref={innerSidebarElementRef}
|
||||
toggleSidebar={toggleSidebar}
|
||||
toggled={toggledSidebar && !settings?.isMobile}
|
||||
backgroundToggled={toggledSidebar || showHistorySidebar}
|
||||
backgroundToggled={toggledSidebar || showDocSidebar}
|
||||
existingChats={chatSessions}
|
||||
currentChatSession={selectedChatSession}
|
||||
folders={folders}
|
||||
@@ -2173,7 +2171,7 @@ export function ChatPage({
|
||||
)}
|
||||
|
||||
<BlurBackground
|
||||
visible={!untoggled && (showHistorySidebar || toggledSidebar)}
|
||||
visible={!untoggled && (showDocSidebar || toggledSidebar)}
|
||||
/>
|
||||
|
||||
<div
|
||||
@@ -2594,11 +2592,13 @@ export function ChatPage({
|
||||
previousMessage &&
|
||||
previousMessage.messageId
|
||||
) {
|
||||
createRegenerator({
|
||||
messageId: message.messageId,
|
||||
parentMessage: parentMessage!,
|
||||
onSubmit({
|
||||
messageIdToResend:
|
||||
previousMessage.messageId,
|
||||
forceSearch: true,
|
||||
})(llmOverrideManager.llmOverride);
|
||||
alternativeAssistantOverride:
|
||||
currentAlternativeAssistant,
|
||||
});
|
||||
} else {
|
||||
setPopup({
|
||||
type: "error",
|
||||
@@ -2831,7 +2831,7 @@ export function ChatPage({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<FixedLogo backgroundToggled={toggledSidebar || showHistorySidebar} />
|
||||
<FixedLogo backgroundToggled={toggledSidebar || showDocSidebar} />
|
||||
</div>
|
||||
{/* Right Sidebar - DocumentSidebar */}
|
||||
</div>
|
||||
|
||||
@@ -7,7 +7,6 @@ import { DocumentUpdatedAtBadge } from "@/components/search/DocumentUpdatedAtBad
|
||||
import { MetadataBadge } from "@/components/MetadataBadge";
|
||||
import { WebResultIcon } from "@/components/WebResultIcon";
|
||||
import { Dispatch, SetStateAction } from "react";
|
||||
import { ValidSources } from "@/lib/types";
|
||||
|
||||
interface DocumentDisplayProps {
|
||||
closeSidebar: () => void;
|
||||
@@ -74,15 +73,19 @@ export function ChatDocumentDisplay({
|
||||
}
|
||||
|
||||
const handleViewFile = async () => {
|
||||
if (document.source_type == ValidSources.File && setPresentingDocument) {
|
||||
setPresentingDocument(document);
|
||||
} else if (document.link) {
|
||||
if (document.link) {
|
||||
window.open(document.link, "_blank");
|
||||
} else {
|
||||
closeSidebar();
|
||||
|
||||
setTimeout(async () => {
|
||||
setPresentingDocument(document);
|
||||
}, 100);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className={`opacity-100 ${modal ? "w-[90vw]" : "w-full"}`}>
|
||||
<div className={`opacity-100 ${modal ? "w-[90vw]" : "w-full"}`}>
|
||||
<div
|
||||
className={`flex relative flex-col gap-0.5 rounded-xl mx-2 my-1 ${
|
||||
isSelected ? "bg-gray-200" : "hover:bg-background-125"
|
||||
|
||||
@@ -12,8 +12,7 @@ export function DocumentSelector({
}) {
const [popupDisabled, setPopupDisabled] = useState(false);

function onClick(e: React.MouseEvent<HTMLInputElement>) {
e.stopPropagation();
function onClick() {
if (!isDisabled) {
setPopupDisabled(true);
handleSelect();
@@ -468,13 +468,7 @@ export const AIMessage = ({
docs
.slice(0, 2)
.map((doc, ind) => (
<SourceCard
doc={doc}
key={ind}
setPresentingDocument={
setPresentingDocument
}
/>
<SourceCard doc={doc} key={ind} />
))}
<SeeMoreBlock
documentSelectionToggled={
@@ -1,6 +1,26 @@
|
||||
import { EmphasizedClickable } from "@/components/BasicClickable";
|
||||
import { FiBook } from "react-icons/fi";
|
||||
|
||||
function ForceSearchButton({
|
||||
messageId,
|
||||
handleShowRetrieved,
|
||||
}: {
|
||||
messageId: number | null;
|
||||
isCurrentlyShowingRetrieved: boolean;
|
||||
handleShowRetrieved: (messageId: number | null) => void;
|
||||
}) {
|
||||
return (
|
||||
<div
|
||||
className="ml-auto my-auto"
|
||||
onClick={() => handleShowRetrieved(messageId)}
|
||||
>
|
||||
<EmphasizedClickable>
|
||||
<div className="w-24 text-xs">Force Search</div>
|
||||
</EmphasizedClickable>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
export function SkippedSearch({
|
||||
handleForceSearch,
|
||||
}: {
|
||||
|
||||
@@ -97,69 +97,3 @@ export function getDatesList(startDate: Date): string[] {
|
||||
|
||||
return datesList;
|
||||
}
|
||||
|
||||
export interface PersonaMessageAnalytics {
|
||||
total_messages: number;
|
||||
date: string;
|
||||
persona_id: number;
|
||||
}
|
||||
|
||||
export interface PersonaSnapshot {
|
||||
id: number;
|
||||
name: string;
|
||||
description: string;
|
||||
is_visible: boolean;
|
||||
is_public: boolean;
|
||||
}
|
||||
|
||||
export const usePersonaMessages = (
|
||||
personaId: number | undefined,
|
||||
timeRange: DateRangePickerValue
|
||||
) => {
|
||||
const url = buildApiPath(`/api/analytics/admin/persona/messages`, {
|
||||
persona_id: personaId?.toString(),
|
||||
start: convertDateToStartOfDay(timeRange.from)?.toISOString(),
|
||||
end: convertDateToEndOfDay(timeRange.to)?.toISOString(),
|
||||
});
|
||||
|
||||
const { data, error, isLoading } = useSWR<PersonaMessageAnalytics[]>(
|
||||
personaId !== undefined ? url : null,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
return {
|
||||
data,
|
||||
error,
|
||||
isLoading,
|
||||
refreshPersonaMessages: () => mutate(url),
|
||||
};
|
||||
};
|
||||
|
||||
export interface PersonaUniqueUserAnalytics {
|
||||
unique_users: number;
|
||||
date: string;
|
||||
persona_id: number;
|
||||
}
|
||||
|
||||
export const usePersonaUniqueUsers = (
|
||||
personaId: number | undefined,
|
||||
timeRange: DateRangePickerValue
|
||||
) => {
|
||||
const url = buildApiPath(`/api/analytics/admin/persona/unique-users`, {
|
||||
persona_id: personaId?.toString(),
|
||||
start: convertDateToStartOfDay(timeRange.from)?.toISOString(),
|
||||
end: convertDateToEndOfDay(timeRange.to)?.toISOString(),
|
||||
});
|
||||
|
||||
const { data, error, isLoading } = useSWR<PersonaUniqueUserAnalytics[]>(
|
||||
personaId !== undefined ? url : null,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
return {
|
||||
data,
|
||||
error,
|
||||
isLoading,
|
||||
refreshPersonaUniqueUsers: () => mutate(url),
|
||||
};
|
||||
};
|
||||
|
||||
@@ -1,231 +0,0 @@
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { X, Search } from "lucide-react";
|
||||
import {
|
||||
getDatesList,
|
||||
usePersonaMessages,
|
||||
usePersonaUniqueUsers,
|
||||
} from "../lib";
|
||||
import { useAssistants } from "@/components/context/AssistantsContext";
|
||||
import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
|
||||
import Text from "@/components/ui/text";
|
||||
import Title from "@/components/ui/title";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
import { AreaChartDisplay } from "@/components/ui/areaChart";
|
||||
import {
|
||||
Select,
|
||||
SelectContent,
|
||||
SelectItem,
|
||||
SelectTrigger,
|
||||
SelectValue,
|
||||
} from "@/components/ui/select";
|
||||
import { useState, useMemo, useEffect } from "react";
|
||||
|
||||
export function PersonaMessagesChart({
|
||||
timeRange,
|
||||
}: {
|
||||
timeRange: DateRangePickerValue;
|
||||
}) {
|
||||
const [selectedPersonaId, setSelectedPersonaId] = useState<
|
||||
number | undefined
|
||||
>(undefined);
|
||||
const [searchQuery, setSearchQuery] = useState("");
|
||||
const [highlightedIndex, setHighlightedIndex] = useState(-1);
|
||||
const { allAssistants: personaList } = useAssistants();
|
||||
|
||||
const {
|
||||
data: personaMessagesData,
|
||||
isLoading: isPersonaMessagesLoading,
|
||||
error: personaMessagesError,
|
||||
} = usePersonaMessages(selectedPersonaId, timeRange);
|
||||
|
||||
const {
|
||||
data: personaUniqueUsersData,
|
||||
isLoading: isPersonaUniqueUsersLoading,
|
||||
error: personaUniqueUsersError,
|
||||
} = usePersonaUniqueUsers(selectedPersonaId, timeRange);
|
||||
|
||||
const isLoading = isPersonaMessagesLoading || isPersonaUniqueUsersLoading;
|
||||
const hasError = personaMessagesError || personaUniqueUsersError;
|
||||
|
||||
const filteredPersonaList = useMemo(() => {
|
||||
if (!personaList) return [];
|
||||
return personaList.filter((persona) =>
|
||||
persona.name.toLowerCase().includes(searchQuery.toLowerCase())
|
||||
);
|
||||
}, [personaList, searchQuery]);
|
||||
|
||||
const handleKeyDown = (e: React.KeyboardEvent) => {
|
||||
e.stopPropagation();
|
||||
|
||||
switch (e.key) {
|
||||
case "ArrowDown":
|
||||
e.preventDefault();
|
||||
setHighlightedIndex((prev) =>
|
||||
prev < filteredPersonaList.length - 1 ? prev + 1 : prev
|
||||
);
|
||||
break;
|
||||
case "ArrowUp":
|
||||
e.preventDefault();
|
||||
setHighlightedIndex((prev) => (prev > 0 ? prev - 1 : prev));
|
||||
break;
|
||||
case "Enter":
|
||||
if (
|
||||
highlightedIndex >= 0 &&
|
||||
highlightedIndex < filteredPersonaList.length
|
||||
) {
|
||||
setSelectedPersonaId(filteredPersonaList[highlightedIndex].id);
|
||||
setSearchQuery("");
|
||||
setHighlightedIndex(-1);
|
||||
}
|
||||
break;
|
||||
case "Escape":
|
||||
setSearchQuery("");
|
||||
setHighlightedIndex(-1);
|
||||
break;
|
||||
}
|
||||
};
|
||||
|
||||
// Reset highlight when search query changes
|
||||
useEffect(() => {
|
||||
setHighlightedIndex(-1);
|
||||
}, [searchQuery]);
|
||||
|
||||
const chartData = useMemo(() => {
|
||||
if (
|
||||
!personaMessagesData?.length ||
|
||||
!personaUniqueUsersData?.length ||
|
||||
selectedPersonaId === undefined
|
||||
) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const initialDate =
|
||||
timeRange.from ||
|
||||
new Date(
|
||||
Math.min(
|
||||
...personaMessagesData.map((entry) => new Date(entry.date).getTime())
|
||||
)
|
||||
);
|
||||
const dateRange = getDatesList(initialDate);
|
||||
|
||||
// Create maps for messages and unique users data
|
||||
const messagesMap = new Map(
|
||||
personaMessagesData.map((entry) => [entry.date, entry])
|
||||
);
|
||||
const uniqueUsersMap = new Map(
|
||||
personaUniqueUsersData.map((entry) => [entry.date, entry])
|
||||
);
|
||||
|
||||
return dateRange.map((dateStr) => {
|
||||
const messageData = messagesMap.get(dateStr);
|
||||
const uniqueUserData = uniqueUsersMap.get(dateStr);
|
||||
return {
|
||||
Day: dateStr,
|
||||
Messages: messageData?.total_messages || 0,
|
||||
"Unique Users": uniqueUserData?.unique_users || 0,
|
||||
};
|
||||
});
|
||||
}, [
|
||||
personaMessagesData,
|
||||
personaUniqueUsersData,
|
||||
timeRange.from,
|
||||
selectedPersonaId,
|
||||
]);
|
||||
|
||||
let content;
|
||||
if (isLoading) {
|
||||
content = (
|
||||
<div className="h-80 flex flex-col">
|
||||
<ThreeDotsLoader />
|
||||
</div>
|
||||
);
|
||||
} else if (!personaList || hasError) {
|
||||
content = (
|
||||
<div className="h-80 text-red-600 text-bold flex flex-col">
|
||||
<p className="m-auto">Failed to fetch data...</p>
|
||||
</div>
|
||||
);
|
||||
} else if (selectedPersonaId === undefined) {
|
||||
content = (
|
||||
<div className="h-80 text-gray-500 flex flex-col">
|
||||
<p className="m-auto">Select a persona to view analytics</p>
|
||||
</div>
|
||||
);
|
||||
} else if (!personaMessagesData?.length) {
|
||||
content = (
|
||||
<div className="h-80 text-gray-500 flex flex-col">
|
||||
<p className="m-auto">
|
||||
No data found for selected persona in the selected time range
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
} else if (chartData) {
|
||||
content = (
|
||||
<AreaChartDisplay
|
||||
className="mt-4"
|
||||
data={chartData}
|
||||
categories={["Messages", "Unique Users"]}
|
||||
index="Day"
|
||||
colors={["indigo", "fuchsia"]}
|
||||
yAxisWidth={60}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const selectedPersona = personaList?.find((p) => p.id === selectedPersonaId);
|
||||
|
||||
return (
|
||||
<CardSection className="mt-8">
|
||||
<Title>Persona Analytics</Title>
|
||||
<div className="flex flex-col gap-4">
|
||||
<Text>Messages and unique users per day for selected persona</Text>
|
||||
<div className="flex items-center gap-4">
|
||||
<Select
|
||||
value={selectedPersonaId?.toString() ?? ""}
|
||||
onValueChange={(value) => {
|
||||
setSelectedPersonaId(parseInt(value));
|
||||
}}
|
||||
>
|
||||
<SelectTrigger className="flex w-full max-w-xs">
|
||||
<SelectValue placeholder="Select a persona to display" />
|
||||
</SelectTrigger>
|
||||
<SelectContent>
|
||||
<div className="flex items-center px-2 pb-2 sticky top-0 bg-background border-b">
|
||||
<Search className="h-4 w-4 mr-2 shrink-0 opacity-50" />
|
||||
<input
|
||||
className="flex h-8 w-full rounded-sm bg-transparent py-3 text-sm outline-none placeholder:text-muted-foreground disabled:cursor-not-allowed disabled:opacity-50"
|
||||
placeholder="Search personas..."
|
||||
value={searchQuery}
|
||||
onChange={(e) => setSearchQuery(e.target.value)}
|
||||
onClick={(e) => e.stopPropagation()}
|
||||
onMouseDown={(e) => e.stopPropagation()}
|
||||
onKeyDown={handleKeyDown}
|
||||
/>
|
||||
{searchQuery && (
|
||||
<X
|
||||
className="h-4 w-4 shrink-0 opacity-50 cursor-pointer hover:opacity-100"
|
||||
onClick={() => {
|
||||
setSearchQuery("");
|
||||
setHighlightedIndex(-1);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
{filteredPersonaList.map((persona, index) => (
|
||||
<SelectItem
|
||||
key={persona.id}
|
||||
value={persona.id.toString()}
|
||||
className={`${highlightedIndex === index ? "hover" : ""}`}
|
||||
onMouseEnter={() => setHighlightedIndex(index)}
|
||||
>
|
||||
{persona.name}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</div>
|
||||
</div>
|
||||
{content}
|
||||
</CardSection>
|
||||
);
|
||||
}
|
||||
@@ -4,7 +4,6 @@ import { DateRangeSelector } from "../DateRangeSelector";
import { DanswerBotChart } from "./DanswerBotChart";
import { FeedbackChart } from "./FeedbackChart";
import { QueryPerformanceChart } from "./QueryPerformanceChart";
import { PersonaMessagesChart } from "./PersonaMessagesChart";
import { useTimeRange } from "../lib";
import { AdminPageTitle } from "@/components/admin/Title";
import { FiActivity } from "react-icons/fi";
@@ -27,7 +26,6 @@ export default function AnalyticsPage() {
<QueryPerformanceChart timeRange={timeRange} />
<FeedbackChart timeRange={timeRange} />
<DanswerBotChart timeRange={timeRange} />
<PersonaMessagesChart timeRange={timeRange} />
<Separator />
<UsageReports />
</main>
@@ -1,5 +0,0 @@
import { redirect } from "next/navigation";

export default function NotFound() {
redirect("/chat");
}
@@ -1,7 +1,6 @@
import { useState, useEffect } from "react";
import faviconFetch from "favicon-fetch";
import { SourceIcon } from "./SourceIcon";
import { ValidSources } from "@/lib/types";

const CACHE_DURATION = 24 * 60 * 60 * 1000;

@@ -46,7 +45,7 @@ export function SearchResultIcon({ url }: { url: string }) {
}, [url]);

if (!faviconUrl) {
return <SourceIcon sourceType={ValidSources.Web} iconSize={18} />;
return <SourceIcon sourceType="web" iconSize={18} />;
}

return (
@@ -1,4 +1,3 @@
import { ValidSources } from "@/lib/types";
import { SourceIcon } from "./SourceIcon";

export function WebResultIcon({ url }: { url: string }) {
@@ -12,6 +11,6 @@ export function WebResultIcon({ url }: { url: string }) {
width={18}
/>
) : (
<SourceIcon sourceType={ValidSources.Web} iconSize={18} />
<SourceIcon sourceType="web" iconSize={18} />
);
}
@@ -1,7 +1,7 @@
import { DefaultDropdown } from "@/components/Dropdown";
import {
AccessType,
ValidAutoSyncSource,
ValidAutoSyncSources,
ConfigurableSources,
validAutoSyncSources,
} from "@/lib/types";
@@ -13,8 +13,8 @@ import { useEffect } from "react";

function isValidAutoSyncSource(
value: ConfigurableSources
): value is ValidAutoSyncSource {
return validAutoSyncSources.includes(value as ValidAutoSyncSource);
): value is ValidAutoSyncSources {
return validAutoSyncSources.includes(value as ValidAutoSyncSources);
}

export function AccessTypeForm({
@@ -92,7 +92,9 @@ export function AccessTypeForm({
/>

{access_type.value === "sync" && isAutoSyncSupported && (
<AutoSyncOptions connectorType={connector as ValidAutoSyncSource} />
<AutoSyncOptions
connectorType={connector as ValidAutoSyncSources}
/>
)}
</>
)}
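The hunk above swaps the name of the type narrowed by a user-defined type guard (`ValidAutoSyncSource` vs `ValidAutoSyncSources`). A minimal sketch of that type-guard pattern with placeholder names and values; the repository's real unions are broader than what is assumed here:

```typescript
// Minimal sketch with hypothetical source names; a `value is T` return type
// lets the compiler narrow the union once the runtime check succeeds.
const validAutoSyncSources = ["confluence", "google_drive"] as const;
type ValidAutoSyncSource = (typeof validAutoSyncSources)[number];
type ConfigurableSource = ValidAutoSyncSource | "web" | "file";

function isValidAutoSyncSource(
  value: ConfigurableSource
): value is ValidAutoSyncSource {
  return (validAutoSyncSources as readonly string[]).includes(value);
}

function describeSync(source: ConfigurableSource): string {
  if (isValidAutoSyncSource(source)) {
    // `source` is narrowed to ValidAutoSyncSource here, so no cast is needed.
    return `${source} supports automatic permission sync`;
  }
  return `${source} requires manual access configuration`;
}

console.log(describeSync("confluence")); // supports automatic permission sync
console.log(describeSync("web")); // requires manual access configuration
```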
Some files were not shown because too many files have changed in this diff.