mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-02-20 01:05:46 +00:00

Compare commits (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 56fd40e606 | |
| | 415d644200 | |
@@ -1,57 +0,0 @@
"""delete_input_prompts

Revision ID: bf7a81109301
Revises: f7a894b06d02
Create Date: 2024-12-09 12:00:49.884228

"""
from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy


# revision identifiers, used by Alembic.
revision = "bf7a81109301"
down_revision = "f7a894b06d02"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_table("inputprompt__user")
    op.drop_table("inputprompt")


def downgrade() -> None:
    op.create_table(
        "inputprompt",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("prompt", sa.String(), nullable=False),
        sa.Column("content", sa.String(), nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False),
        sa.Column("is_public", sa.Boolean(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "inputprompt__user",
        sa.Column("input_prompt_id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["input_prompt_id"],
            ["inputprompt.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["inputprompt.id"],
        ),
        sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
    )
@@ -1,4 +1,3 @@
-import hashlib
 import secrets
 import uuid
 from urllib.parse import quote
@@ -19,8 +18,7 @@ _API_KEY_HEADER_NAME = "Authorization"
 # organizations like the Internet Engineering Task Force (IETF).
 _API_KEY_HEADER_ALTERNATIVE_NAME = "X-Danswer-Authorization"
 _BEARER_PREFIX = "Bearer "
-_API_KEY_PREFIX = "on_"
-_DEPRECATED_API_KEY_PREFIX = "dn_"
+_API_KEY_PREFIX = "dn_"
 _API_KEY_LEN = 192
@@ -54,9 +52,7 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:

     api_key = raw_api_key_header[len(_BEARER_PREFIX) :].strip()

-    if not api_key.startswith(_API_KEY_PREFIX) and not api_key.startswith(
-        _DEPRECATED_API_KEY_PREFIX
-    ):
+    if not api_key.startswith(_API_KEY_PREFIX):
         return None

     parts = api_key[len(_API_KEY_PREFIX) :].split(".", 1)
@@ -67,19 +63,10 @@ def extract_tenant_from_api_key_header(request: Request) -> str | None:
     return unquote(tenant_id) if tenant_id else None


-def _deprecated_hash_api_key(api_key: str) -> str:
-    return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)
-
-
 def hash_api_key(api_key: str) -> str:
     # NOTE: no salt is needed, as the API key is randomly generated
     # and overlaps are impossible
-    if api_key.startswith(_API_KEY_PREFIX):
-        return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
-    elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
-        return _deprecated_hash_api_key(api_key)
-    else:
-        raise ValueError(f"Invalid API key prefix: {api_key[:3]}")
+    return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)


 def build_displayable_api_key(api_key: str) -> str:
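The hunk above collapses a prefix-dispatching hash (plain SHA-256 for `on_` keys, the legacy `sha256_crypt` scheme for `dn_` keys) back into the single legacy scheme. For context on the NOTE comment: an unsalted digest is acceptable here because the key itself is high-entropy random data. A minimal sketch of that reasoning, illustrative only; `_PREFIX` and the helper names are hypothetical:

```python
import hashlib
import secrets

_PREFIX = "on_"  # hypothetical prefix, mirroring the removed constant above

def make_api_key() -> str:
    # The key is high-entropy randomness; there is nothing for an
    # attacker to rainbow-table, so no salt is needed when hashing it.
    return _PREFIX + secrets.token_urlsafe(64)

def hash_for_storage(api_key: str) -> str:
    return hashlib.sha256(api_key.encode("utf-8")).hexdigest()

key = make_api_key()
stored = hash_for_storage(key)          # persist the digest, never the raw key
assert hash_for_storage(key) == stored  # verification is re-hash and compare
```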
@@ -9,6 +9,7 @@ from danswer.utils.special_types import JSON_ro
 def get_invited_users() -> list[str]:
     try:
         store = get_kv_store()

         return cast(list, store.load(KV_USER_STORE_KEY))
     except KvKeyNotFoundError:
         return list()
@@ -131,7 +131,7 @@ def get_display_email(email: str | None, space_less: bool = False) -> str:


 def user_needs_to_be_verified() -> bool:
-    if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD:
+    if AUTH_TYPE == AuthType.BASIC:
         return REQUIRE_EMAIL_VERIFICATION

     # For other auth types, if the user is authenticated it's assumed that
@@ -598,7 +598,7 @@ def connector_indexing_proxy_task(
             db_session,
             "Connector termination signal detected",
         )
-    except Exception:
+    finally:
         # if the DB exceptions, we'll just get an unfriendly failure message
         # in the UI instead of the cancellation message
         logger.exception(
@@ -680,28 +680,17 @@ def monitor_ccpair_indexing_taskset(
         )
         task_logger.warning(msg)

-        try:
-            index_attempt = get_index_attempt(
-                db_session, payload.index_attempt_id
-            )
-            if index_attempt:
-                if (
-                    index_attempt.status != IndexingStatus.CANCELED
-                    and index_attempt.status != IndexingStatus.FAILED
-                ):
-                    mark_attempt_failed(
-                        index_attempt_id=payload.index_attempt_id,
-                        db_session=db_session,
-                        failure_reason=msg,
-                    )
-        except Exception:
-            task_logger.exception(
-                "monitor_ccpair_indexing_taskset - transient exception marking index attempt as failed: "
-                f"attempt={payload.index_attempt_id} "
-                f"tenant={tenant_id} "
-                f"cc_pair={cc_pair_id} "
-                f"search_settings={search_settings_id}"
-            )
+        index_attempt = get_index_attempt(db_session, payload.index_attempt_id)
+        if index_attempt:
+            if (
+                index_attempt.status != IndexingStatus.CANCELED
+                and index_attempt.status != IndexingStatus.FAILED
+            ):
+                mark_attempt_failed(
+                    index_attempt_id=payload.index_attempt_id,
+                    db_session=db_session,
+                    failure_reason=msg,
+                )

         redis_connector_index.reset()
         return
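The try/except removed above followed a common hardening pattern: status bookkeeping inside a monitoring loop is best-effort, so a transient DB error is logged rather than allowed to kill the task. A minimal sketch of the pattern, with a hypothetical stand-in for the real DB call:

```python
import logging

logger = logging.getLogger(__name__)

def mark_attempt_failed(attempt_id: int, reason: str) -> None:
    ...  # hypothetical stand-in for the DB update in the diff above

def monitor_step(attempt_id: int, reason: str) -> None:
    # Best-effort bookkeeping: a transient DB failure is logged rather
    # than allowed to crash the long-running monitoring task.
    try:
        mark_attempt_failed(attempt_id, reason)
    except Exception:
        logger.exception(
            "transient exception marking attempt %s as failed", attempt_id
        )
    # cleanup (e.g. redis_connector_index.reset()) still runs afterwards
```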
@@ -206,9 +206,7 @@ class Answer:
         # + figure out what the next LLM call should be
         tool_call_handler = ToolResponseHandler(current_llm_call.tools)

-        search_result, displayed_search_results_map = SearchTool.get_search_result(
-            current_llm_call
-        ) or ([], {})
+        search_result = SearchTool.get_search_result(current_llm_call) or []

         # Quotes are no longer supported
         # answer_handler: AnswerResponseHandler
@@ -226,7 +224,6 @@ class Answer:
         answer_handler = CitationResponseHandler(
             context_docs=search_result,
             doc_id_to_rank_map=map_document_id_order(search_result),
-            display_doc_order_dict=displayed_search_results_map,
         )

         response_handler_manager = LLMResponseHandlerManager(
@@ -35,18 +35,13 @@ class DummyAnswerResponseHandler(AnswerResponseHandler):

 class CitationResponseHandler(AnswerResponseHandler):
     def __init__(
-        self,
-        context_docs: list[LlmDoc],
-        doc_id_to_rank_map: DocumentIdOrderMapping,
-        display_doc_order_dict: dict[str, int],
+        self, context_docs: list[LlmDoc], doc_id_to_rank_map: DocumentIdOrderMapping
     ):
         self.context_docs = context_docs
         self.doc_id_to_rank_map = doc_id_to_rank_map
-        self.display_doc_order_dict = display_doc_order_dict
         self.citation_processor = CitationProcessor(
             context_docs=self.context_docs,
             doc_id_to_rank_map=self.doc_id_to_rank_map,
-            display_doc_order_dict=self.display_doc_order_dict,
         )
         self.processed_text = ""
         self.citations: list[CitationInfo] = []
@@ -22,16 +22,12 @@ class CitationProcessor:
         self,
         context_docs: list[LlmDoc],
         doc_id_to_rank_map: DocumentIdOrderMapping,
-        display_doc_order_dict: dict[str, int],
         stop_stream: str | None = STOP_STREAM_PAT,
     ):
         self.context_docs = context_docs
         self.doc_id_to_rank_map = doc_id_to_rank_map
         self.stop_stream = stop_stream
         self.order_mapping = doc_id_to_rank_map.order_mapping
-        self.display_doc_order_dict = (
-            display_doc_order_dict  # original order of docs as displayed to the user
-        )
         self.llm_out = ""
         self.max_citation_num = len(context_docs)
         self.citation_order: list[int] = []
@@ -102,18 +98,6 @@ class CitationProcessor:
                     self.citation_order.index(real_citation_num) + 1
                 )

-                # get the value that was displayed to the user; it should always
-                # be in the display_doc_order_dict, but check anyway
-                if context_llm_doc.document_id in self.display_doc_order_dict:
-                    displayed_citation_num = self.display_doc_order_dict[
-                        context_llm_doc.document_id
-                    ]
-                else:
-                    displayed_citation_num = real_citation_num
-                    logger.warning(
-                        f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
-                    )
-
                 # Skip consecutive citations of the same work
                 if target_citation_num in self.current_citations:
                     start, end = citation.span()
@@ -134,7 +118,6 @@ class CitationProcessor:
                     doc_id = int(match.group(1))
                     context_llm_doc = self.context_docs[doc_id - 1]
                     yield CitationInfo(
-                        # stay with the original for now (order of LLM cites)
                         citation_num=target_citation_num,
                         document_id=context_llm_doc.document_id,
                     )
@@ -156,7 +139,6 @@ class CitationProcessor:
                 if target_citation_num not in self.cited_inds:
                     self.cited_inds.add(target_citation_num)
                     yield CitationInfo(
-                        # stay with the original for now (order of LLM cites)
                         citation_num=target_citation_num,
                         document_id=context_llm_doc.document_id,
                     )
@@ -166,8 +148,7 @@ class CitationProcessor:
                     prev_length = len(self.curr_segment)
                     self.curr_segment = (
                         self.curr_segment[: start + length_to_add]
-                        + f"[[{displayed_citation_num}]]({link})"  # use the value that was displayed to the user
-                        # + f"[[{target_citation_num}]]({link})"
+                        + f"[[{target_citation_num}]]({link})"
                         + self.curr_segment[end + length_to_add :]
                     )
                     length_to_add += len(self.curr_segment) - prev_length
@@ -175,8 +156,7 @@ class CitationProcessor:
                     prev_length = len(self.curr_segment)
                     self.curr_segment = (
                         self.curr_segment[: start + length_to_add]
-                        + f"[[{displayed_citation_num}]]()"  # use the value that was displayed to the user
-                        # + f"[[{target_citation_num}]]()"
+                        + f"[[{target_citation_num}]]()"
                         + self.curr_segment[end + length_to_add :]
                     )
                     length_to_add += len(self.curr_segment) - prev_length
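The deleted display_doc_order_dict machinery kept two numbering schemes side by side: the order in which the LLM cited documents and the order in which documents were shown to the user, remapping each citation to the displayed number before rendering. A small sketch of that remapping with hypothetical data:

```python
# Position each document was displayed at (1-based), keyed by document id.
display_order = {"doc-a": 1, "doc-b": 2, "doc-c": 3}

def remap_citation(doc_id: str, llm_citation_num: int) -> int:
    # Prefer the number the user actually saw; fall back to the LLM's
    # own citation order when the doc is missing (as the warning branch did).
    return display_order.get(doc_id, llm_citation_num)

assert remap_citation("doc-c", 1) == 3  # LLM cited it first, user saw it third
assert remap_citation("doc-x", 4) == 4  # unknown doc keeps the LLM's number
```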
@@ -348,12 +348,6 @@ GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
     os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
 )

-# Egnyte specific configs
-EGNYTE_LOCALHOST_OVERRIDE = os.getenv("EGNYTE_LOCALHOST_OVERRIDE")
-EGNYTE_BASE_DOMAIN = os.getenv("EGNYTE_DOMAIN")
-EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
-EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")
-
 DASK_JOB_CLIENT_ENABLED = (
     os.environ.get("DASK_JOB_CLIENT_ENABLED", "").lower() == "true"
 )
@@ -417,28 +411,21 @@ LARGE_CHUNK_RATIO = 4
 # We don't want the metadata to overwhelm the actual contents of the chunk
 SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"
 # Timeout to wait for a job's last update before killing it, in hours
-CLEANUP_INDEXING_JOBS_TIMEOUT = int(
-    os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT") or 3
-)
+CLEANUP_INDEXING_JOBS_TIMEOUT = int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))

 # The indexer will warn in the logs whenever a document exceeds this threshold (in bytes)
 INDEXING_SIZE_WARNING_THRESHOLD = int(
-    os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD") or 100 * 1024 * 1024
+    os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD", 100 * 1024 * 1024)
 )

 # during indexing, will log verbose memory diff stats every x batches and at the end.
 # 0 disables this behavior and is the default.
-INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)
+INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL", 0))

 # During an indexing attempt, specifies the number of batches which are allowed to
 # exception without aborting the attempt.
-INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT") or 0)
+INDEXING_EXCEPTION_LIMIT = int(os.environ.get("INDEXING_EXCEPTION_LIMIT", 0))

 # Maximum file size in a document to be indexed
 MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000)
-MAX_FILE_SIZE_BYTES = int(
-    os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
-)  # 2GB in bytes

 #####
 # Miscellaneous
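A note on the default-value changes above (an observation about the two idioms, not part of the diff): `os.environ.get(X, default)` and `os.environ.get(X) or default` behave differently when the variable is set to an empty string. The former returns `""`, so `int("")` raises; the latter falls back to the default:

```python
import os

os.environ["CLEANUP_INDEXING_JOBS_TIMEOUT"] = ""  # set, but empty

# `or` treats the empty string as falsy and falls back to the default:
assert int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT") or 3) == 3

# The dict-style default only applies when the key is missing entirely:
try:
    int(os.environ.get("CLEANUP_INDEXING_JOBS_TIMEOUT", 3))
except ValueError:
    print("int('') raises - the default was never consulted")
```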
@@ -3,6 +3,7 @@ import os

 PROMPTS_YAML = "./danswer/seeding/prompts.yaml"
 PERSONAS_YAML = "./danswer/seeding/personas.yaml"
+INPUT_PROMPT_YAML = "./danswer/seeding/input_prompts.yaml"

 NUM_RETURNED_HITS = 50
 # Used for LLM filtering and reranking
@@ -132,7 +132,6 @@ class DocumentSource(str, Enum):
     NOT_APPLICABLE = "not_applicable"
     FRESHDESK = "freshdesk"
     FIREFLIES = "fireflies"
-    EGNYTE = "egnyte"


 DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE]
@@ -1,384 +0,0 @@
import io
import os
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from logging import Logger
from typing import Any
from typing import cast
from typing import IO

import requests
from retry import retry

from danswer.configs.app_configs import EGNYTE_BASE_DOMAIN
from danswer.configs.app_configs import EGNYTE_CLIENT_ID
from danswer.configs.app_configs import EGNYTE_CLIENT_SECRET
from danswer.configs.app_configs import EGNYTE_LOCALHOST_OVERRIDE
from danswer.configs.app_configs import INDEX_BATCH_SIZE
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import GenerateDocumentsOutput
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import OAuthConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.interfaces import SecondsSinceUnixEpoch
from danswer.connectors.models import BasicExpertInfo
from danswer.connectors.models import ConnectorMissingCredentialError
from danswer.connectors.models import Document
from danswer.connectors.models import Section
from danswer.file_processing.extract_file_text import detect_encoding
from danswer.file_processing.extract_file_text import extract_file_text
from danswer.file_processing.extract_file_text import get_file_ext
from danswer.file_processing.extract_file_text import is_text_file_extension
from danswer.file_processing.extract_file_text import is_valid_file_ext
from danswer.file_processing.extract_file_text import read_text_file
from danswer.utils.logger import setup_logger


logger = setup_logger()

_EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1"
_EGNYTE_APP_BASE = "https://{domain}.egnyte.com"
_TIMEOUT = 60


def _request_with_retries(
    method: str,
    url: str,
    data: dict[str, Any] | None = None,
    headers: dict[str, Any] | None = None,
    params: dict[str, Any] | None = None,
    timeout: int = _TIMEOUT,
    stream: bool = False,
    tries: int = 8,
    delay: float = 1,
    backoff: float = 2,
) -> requests.Response:
    @retry(tries=tries, delay=delay, backoff=backoff, logger=cast(Logger, logger))
    def _make_request() -> requests.Response:
        response = requests.request(
            method,
            url,
            data=data,
            headers=headers,
            params=params,
            timeout=timeout,
            stream=stream,
        )
        try:
            response.raise_for_status()
        except requests.exceptions.HTTPError as e:
            if e.response.status_code != 403:
                logger.exception(
                    f"Failed to call Egnyte API.\n"
                    f"URL: {url}\n"
                    f"Headers: {headers}\n"
                    f"Data: {data}\n"
                    f"Params: {params}"
                )
            raise e
        return response

    return _make_request()


def _parse_last_modified(last_modified: str) -> datetime:
    return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace(
        tzinfo=timezone.utc
    )


def _process_egnyte_file(
    file_metadata: dict[str, Any],
    file_content: IO,
    base_url: str,
    folder_path: str | None = None,
) -> Document | None:
    """Process an Egnyte file into a Document object

    Args:
        file_metadata: The file metadata from the Egnyte API
        file_content: The raw content of the file in bytes
        base_url: The base URL for the Egnyte instance
        folder_path: Optional folder path to filter results
    """
    # Skip if file path doesn't match folder path filter
    if folder_path and not file_metadata["path"].startswith(folder_path):
        raise ValueError(
            f"File path {file_metadata['path']} does not match folder path {folder_path}"
        )

    file_name = file_metadata["name"]
    extension = get_file_ext(file_name)
    if not is_valid_file_ext(extension):
        logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
        return None

    # Extract text content based on file type
    if is_text_file_extension(file_name):
        encoding = detect_encoding(file_content)
        file_content_raw, file_metadata = read_text_file(
            file_content, encoding=encoding, ignore_danswer_metadata=False
        )
    else:
        file_content_raw = extract_file_text(
            file=file_content,
            file_name=file_name,
            break_on_unprocessable=True,
        )

    # Build the web URL for the file
    web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}"

    # Create document metadata
    metadata: dict[str, str | list[str]] = {
        "file_path": file_metadata["path"],
        "last_modified": file_metadata.get("last_modified", ""),
    }

    # Add lock info if present
    if lock_info := file_metadata.get("lock_info"):
        metadata[
            "lock_owner"
        ] = f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}"

    # Create the document owners
    primary_owner = None
    if uploaded_by := file_metadata.get("uploaded_by"):
        primary_owner = BasicExpertInfo(
            email=uploaded_by,  # Using username as email since that's what we have
        )

    # Create the document
    return Document(
        id=f"egnyte-{file_metadata['entry_id']}",
        sections=[Section(text=file_content_raw.strip(), link=web_url)],
        source=DocumentSource.EGNYTE,
        semantic_identifier=file_name,
        metadata=metadata,
        doc_updated_at=(
            _parse_last_modified(file_metadata["last_modified"])
            if "last_modified" in file_metadata
            else None
        ),
        primary_owners=[primary_owner] if primary_owner else None,
    )


class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector):
    def __init__(
        self,
        folder_path: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.domain = ""  # will always be set in `load_credentials`
        self.folder_path = folder_path or ""  # Root folder if not specified
        self.batch_size = batch_size
        self.access_token: str | None = None

    @classmethod
    def oauth_id(cls) -> DocumentSource:
        return DocumentSource.EGNYTE

    @classmethod
    def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
        if not EGNYTE_CLIENT_ID:
            raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
        if not EGNYTE_BASE_DOMAIN:
            raise ValueError("EGNYTE_DOMAIN environment variable must be set")

        if EGNYTE_LOCALHOST_OVERRIDE:
            base_domain = EGNYTE_LOCALHOST_OVERRIDE

        callback_uri = f"{base_domain.strip('/')}/connector/oauth/callback/egnyte"
        return (
            f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
            f"?client_id={EGNYTE_CLIENT_ID}"
            f"&redirect_uri={callback_uri}"
            f"&scope=Egnyte.filesystem"
            f"&state={state}"
            f"&response_type=code"
        )

    @classmethod
    def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
        if not EGNYTE_CLIENT_ID:
            raise ValueError("EGNYTE_CLIENT_ID environment variable must be set")
        if not EGNYTE_CLIENT_SECRET:
            raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set")
        if not EGNYTE_BASE_DOMAIN:
            raise ValueError("EGNYTE_DOMAIN environment variable must be set")

        # Exchange code for token
        url = f"https://{EGNYTE_BASE_DOMAIN}.egnyte.com/puboauth/token"
        data = {
            "client_id": EGNYTE_CLIENT_ID,
            "client_secret": EGNYTE_CLIENT_SECRET,
            "code": code,
            "grant_type": "authorization_code",
            "redirect_uri": f"{EGNYTE_LOCALHOST_OVERRIDE or ''}/connector/oauth/callback/egnyte",
            "scope": "Egnyte.filesystem",
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}

        response = _request_with_retries(
            method="POST",
            url=url,
            data=data,
            headers=headers,
            # try a lot faster since this is a realtime flow
            backoff=0,
            delay=0.1,
        )
        if not response.ok:
            raise RuntimeError(f"Failed to exchange code for token: {response.text}")

        token_data = response.json()
        return {
            "domain": EGNYTE_BASE_DOMAIN,
            "access_token": token_data["access_token"],
        }

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.domain = credentials["domain"]
        self.access_token = credentials["access_token"]
        return None

    def _get_files_list(
        self,
        path: str,
    ) -> list[dict[str, Any]]:
        if not self.access_token or not self.domain:
            raise ConnectorMissingCredentialError("Egnyte")

        headers = {
            "Authorization": f"Bearer {self.access_token}",
        }

        params: dict[str, Any] = {
            "list_content": True,
        }

        url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{path or ''}"
        response = _request_with_retries(
            method="GET", url=url, headers=headers, params=params, timeout=_TIMEOUT
        )
        if not response.ok:
            raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}")

        data = response.json()
        all_files: list[dict[str, Any]] = []

        # Add files from current directory
        all_files.extend(data.get("files", []))

        # Recursively traverse folders
        for item in data.get("folders", []):
            all_files.extend(self._get_files_list(item["path"]))

        return all_files

    def _filter_files(
        self,
        files: list[dict[str, Any]],
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> list[dict[str, Any]]:
        filtered_files = []
        for file in files:
            if file["is_folder"]:
                continue

            file_modified = _parse_last_modified(file["last_modified"])
            if start_time and file_modified < start_time:
                continue
            if end_time and file_modified > end_time:
                continue

            filtered_files.append(file)

        return filtered_files

    def _process_files(
        self,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
    ) -> Generator[list[Document], None, None]:
        files = self._get_files_list(self.folder_path)
        files = self._filter_files(files, start_time, end_time)

        current_batch: list[Document] = []
        for file in files:
            try:
                # Set up request with streaming enabled
                headers = {
                    "Authorization": f"Bearer {self.access_token}",
                }
                url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{file['path']}"
                response = _request_with_retries(
                    method="GET",
                    url=url,
                    headers=headers,
                    timeout=_TIMEOUT,
                    stream=True,
                )

                if not response.ok:
                    logger.error(
                        f"Failed to fetch file content: {file['path']} (status code: {response.status_code})"
                    )
                    continue

                # Stream the response content into a BytesIO buffer
                buffer = io.BytesIO()
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        buffer.write(chunk)

                # Reset buffer's position to the start
                buffer.seek(0)

                # Process the streamed file content
                doc = _process_egnyte_file(
                    file_metadata=file,
                    file_content=buffer,
                    base_url=_EGNYTE_APP_BASE.format(domain=self.domain),
                    folder_path=self.folder_path,
                )

                if doc is not None:
                    current_batch.append(doc)

                    if len(current_batch) >= self.batch_size:
                        yield current_batch
                        current_batch = []

            except Exception:
                logger.exception(f"Failed to process file {file['path']}")
                continue

        if current_batch:
            yield current_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        yield from self._process_files()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        start_time = datetime.fromtimestamp(start, tz=timezone.utc)
        end_time = datetime.fromtimestamp(end, tz=timezone.utc)

        yield from self._process_files(start_time=start_time, end_time=end_time)


if __name__ == "__main__":
    connector = EgnyteConnector()
    connector.load_credentials(
        {
            "domain": os.environ["EGNYTE_DOMAIN"],
            "access_token": os.environ["EGNYTE_ACCESS_TOKEN"],
        }
    )
    document_batches = connector.load_from_state()
    print(next(document_batches))
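One detail of the deleted connector worth highlighting: `_request_with_retries` builds and decorates a fresh inner function on every call, so the retry policy (tries, delay, backoff) is chosen per call site rather than fixed at import time. The OAuth exchange passes `delay=0.1, backoff=0` for exactly this reason. A minimal sketch of the pattern, assuming the same third-party `retry` package:

```python
import logging
import requests
from retry import retry  # same decorator package used by the connector above

logger = logging.getLogger(__name__)

def fetch(url: str, tries: int = 8, delay: float = 1, backoff: float = 2) -> requests.Response:
    # Decorating an inner closure means each caller chooses its own retry
    # policy, e.g. fast retries for a realtime flow, slow ones for batch work.
    @retry(tries=tries, delay=delay, backoff=backoff, logger=logger)
    def _do_fetch() -> requests.Response:
        response = requests.get(url, timeout=60)
        response.raise_for_status()  # raising is what triggers a retry
        return response

    return _do_fetch()
```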
@@ -15,7 +15,6 @@ from danswer.connectors.danswer_jira.connector import JiraConnector
 from danswer.connectors.discourse.connector import DiscourseConnector
 from danswer.connectors.document360.connector import Document360Connector
 from danswer.connectors.dropbox.connector import DropboxConnector
-from danswer.connectors.egnyte.connector import EgnyteConnector
 from danswer.connectors.file.connector import LocalFileConnector
 from danswer.connectors.fireflies.connector import FirefliesConnector
 from danswer.connectors.freshdesk.connector import FreshdeskConnector
@@ -104,7 +103,6 @@ def identify_connector_class(
         DocumentSource.XENFORO: XenforoConnector,
         DocumentSource.FRESHDESK: FreshdeskConnector,
         DocumentSource.FIREFLIES: FirefliesConnector,
-        DocumentSource.EGNYTE: EgnyteConnector,
     }
     connector_by_source = connector_map.get(source, {})
@@ -17,11 +17,11 @@ from danswer.connectors.models import BasicExpertInfo
 from danswer.connectors.models import Document
 from danswer.connectors.models import Section
 from danswer.db.engine import get_session_with_tenant
+from danswer.file_processing.extract_file_text import check_file_ext_is_valid
 from danswer.file_processing.extract_file_text import detect_encoding
 from danswer.file_processing.extract_file_text import extract_file_text
 from danswer.file_processing.extract_file_text import get_file_ext
 from danswer.file_processing.extract_file_text import is_text_file_extension
-from danswer.file_processing.extract_file_text import is_valid_file_ext
 from danswer.file_processing.extract_file_text import load_files_from_zip
 from danswer.file_processing.extract_file_text import read_pdf_file
 from danswer.file_processing.extract_file_text import read_text_file
@@ -50,7 +50,7 @@ def _read_files_and_metadata(
         file_content, ignore_dirs=True
     ):
         yield os.path.join(directory_path, file_info.filename), file, metadata
-    elif is_valid_file_ext(extension):
+    elif check_file_ext_is_valid(extension):
         yield file_name, file_content, metadata
     else:
         logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
@@ -63,7 +63,7 @@ def _process_file(
     pdf_pass: str | None = None,
 ) -> list[Document]:
     extension = get_file_ext(file_name)
-    if not is_valid_file_ext(extension):
+    if not check_file_ext_is_valid(extension):
         logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
         return []
@@ -4,13 +4,11 @@ from concurrent.futures import as_completed
 from concurrent.futures import ThreadPoolExecutor
 from functools import partial
 from typing import Any
-from typing import cast

 from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
 from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore

 from danswer.configs.app_configs import INDEX_BATCH_SIZE
-from danswer.configs.app_configs import MAX_FILE_SIZE_BYTES
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.google_drive.doc_conversion import build_slim_document
 from danswer.connectors.google_drive.doc_conversion import (
@@ -454,14 +452,12 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
             if isinstance(self.creds, ServiceAccountCredentials)
             else self._manage_oauth_retrieval
         )
-        drive_files = retrieval_method(
+        return retrieval_method(
             is_slim=is_slim,
             start=start,
             end=end,
         )

-        return drive_files
-
     def _extract_docs_from_google_drive(
         self,
         start: SecondsSinceUnixEpoch | None = None,
@@ -477,15 +473,6 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
         files_to_process = []
         # Gather the files into batches to be processed in parallel
         for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
-            if (
-                file.get("size")
-                and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES
-            ):
-                logger.warning(
-                    f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes"
-                )
-                continue
-
             files_to_process.append(file)
             if len(files_to_process) >= LARGE_BATCH_SIZE:
                 yield from _process_files_batch(
@@ -16,7 +16,7 @@ logger = setup_logger()

 FILE_FIELDS = (
     "nextPageToken, files(mimeType, id, name, permissions, modifiedTime, webViewLink, "
-    "shortcutDetails, owners(emailAddress), size)"
+    "shortcutDetails, owners(emailAddress))"
 )
 SLIM_FILE_FIELDS = (
     "nextPageToken, files(mimeType, id, name, permissions(emailAddress, type), "
@@ -2,7 +2,6 @@ import abc
 from collections.abc import Iterator
 from typing import Any

-from danswer.configs.constants import DocumentSource
 from danswer.connectors.models import Document
 from danswer.connectors.models import SlimDocument

@@ -65,23 +64,6 @@ class SlimConnector(BaseConnector):
         raise NotImplementedError


-class OAuthConnector(BaseConnector):
-    @classmethod
-    @abc.abstractmethod
-    def oauth_id(cls) -> DocumentSource:
-        raise NotImplementedError
-
-    @classmethod
-    @abc.abstractmethod
-    def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
-        raise NotImplementedError
-
-    @classmethod
-    @abc.abstractmethod
-    def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
-        raise NotImplementedError
-
-
 # Event driven
 class EventConnector(BaseConnector):
     @abc.abstractmethod
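The removed OAuthConnector interface is what the Egnyte connector earlier in this diff implemented: three classmethods covering the authorize-redirect and code-exchange legs of an OAuth flow. A minimal sketch of a conforming subclass, assuming the imports from this interfaces module are in scope (the provider URL and token shape are hypothetical):

```python
from typing import Any

class ExampleOAuthConnector(OAuthConnector):  # the interface deleted above
    @classmethod
    def oauth_id(cls) -> DocumentSource:
        # Identifies which document source this OAuth flow belongs to.
        return DocumentSource.EGNYTE

    @classmethod
    def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
        # Where the user's browser is sent; `state` round-trips through
        # the provider so the callback can be validated.
        return f"https://provider.example.com/oauth?state={state}"

    @classmethod
    def oauth_code_to_token(cls, code: str) -> dict[str, Any]:
        # Exchange the one-time authorization code for stored credentials.
        return {"access_token": f"token-for-{code}"}
```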
@@ -132,6 +132,7 @@ class LinearConnector(LoadConnector, PollConnector):
                         branchName
                         customerTicketCount
+                        description
                         descriptionData
                         comments {
                             nodes {
                                 url
@@ -214,6 +215,5 @@ class LinearConnector(LoadConnector, PollConnector):
 if __name__ == "__main__":
     connector = LinearConnector()
     connector.load_credentials({"linear_api_key": os.environ["LINEAR_API_KEY"]})
-
     document_batches = connector.load_from_state()
     print(next(document_batches))
@@ -171,9 +171,7 @@ def thread_to_doc(
         else first_message
     )

-    doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace(
-        "\n", " "
-    )
+    doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}"

     return Document(
         id=f"{channel_id}__{thread[0]['ts']}",
@@ -204,8 +204,7 @@ def _build_documents_blocks(
             continue
         seen_docs_identifiers.add(d.document_id)

-        # Strip newlines from the semantic identifier for Slackbot formatting
-        doc_sem_id = d.semantic_identifier.replace("\n", " ")
+        doc_sem_id = d.semantic_identifier
         if d.source_type == DocumentSource.SLACK.value:
             doc_sem_id = "#" + doc_sem_id
@@ -373,9 +373,7 @@ def handle_regular_answer(
         respond_in_thread(
             client=client,
             channel=channel,
-            receiver_ids=[message_info.sender]
-            if message_info.is_bot_msg and message_info.sender
-            else receiver_ids,
+            receiver_ids=receiver_ids,
             text="Hello! Danswer has some results for you!",
             blocks=all_blocks,
             thread_ts=message_ts_to_respond_to,
@@ -11,7 +11,6 @@ from retry import retry
 from slack_sdk import WebClient
 from slack_sdk.errors import SlackApiError
 from slack_sdk.models.blocks import Block
-from slack_sdk.models.blocks import SectionBlock
 from slack_sdk.models.metadata import Metadata
 from slack_sdk.socket_mode import SocketModeClient

@@ -141,40 +140,6 @@ def remove_danswer_bot_tag(message_str: str, client: WebClient) -> str:
     return re.sub(rf"<@{bot_tag_id}>\s", "", message_str)


-def _check_for_url_in_block(block: Block) -> bool:
-    """
-    Check if the block has a key that contains "url" in it
-    """
-    block_dict = block.to_dict()
-
-    def check_dict_for_url(d: dict) -> bool:
-        for key, value in d.items():
-            if "url" in key.lower():
-                return True
-            if isinstance(value, dict):
-                if check_dict_for_url(value):
-                    return True
-            elif isinstance(value, list):
-                for item in value:
-                    if isinstance(item, dict) and check_dict_for_url(item):
-                        return True
-        return False
-
-    return check_dict_for_url(block_dict)
-
-
-def _build_error_block(error_message: str) -> Block:
-    """
-    Build an error block to display in Slack so that the user can see
-    the error without completely breaking
-    """
-    display_text = (
-        "There was an error displaying all of the Onyx answers."
-        f" Please let an admin or an onyx developer know. Error: {error_message}"
-    )
-    return SectionBlock(text=display_text)
-
-
 @retry(
     tries=DANSWER_BOT_NUM_RETRIES,
     delay=0.25,
@@ -197,9 +162,24 @@ def respond_in_thread(
     message_ids: list[str] = []
     if not receiver_ids:
         slack_call = make_slack_api_rate_limited(client.chat_postMessage)
-        try:
         response = slack_call(
             channel=channel,
             text=text,
             blocks=blocks,
             thread_ts=thread_ts,
             metadata=metadata,
             unfurl_links=unfurl,
             unfurl_media=unfurl,
         )
+        if not response.get("ok"):
+            raise RuntimeError(f"Failed to post message: {response}")
+        message_ids.append(response["message_ts"])
+    else:
+        slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
+        for receiver in receiver_ids:
+            response = slack_call(
+                channel=channel,
+                user=receiver,
+                text=text,
+                blocks=blocks,
+                thread_ts=thread_ts,
@@ -207,68 +187,8 @@ def respond_in_thread(
                 unfurl_links=unfurl,
                 unfurl_media=unfurl,
             )
-        except Exception as e:
-            logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
-            logger.warning("Trying again without blocks that have urls")
-
-            if not blocks:
-                raise e
-
-            blocks_without_urls = [
-                block for block in blocks if not _check_for_url_in_block(block)
-            ]
-            blocks_without_urls.append(_build_error_block(str(e)))
-
-            # Try again without blocks containing url
-            response = slack_call(
-                channel=channel,
-                text=text,
-                blocks=blocks_without_urls,
-                thread_ts=thread_ts,
-                metadata=metadata,
-                unfurl_links=unfurl,
-                unfurl_media=unfurl,
-            )
-
-        message_ids.append(response["message_ts"])
-    else:
-        slack_call = make_slack_api_rate_limited(client.chat_postEphemeral)
-        for receiver in receiver_ids:
-            try:
-                response = slack_call(
-                    channel=channel,
-                    user=receiver,
-                    text=text,
-                    blocks=blocks,
-                    thread_ts=thread_ts,
-                    metadata=metadata,
-                    unfurl_links=unfurl,
-                    unfurl_media=unfurl,
-                )
-            except Exception as e:
-                logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
-                logger.warning("Trying again without blocks that have urls")
-
-                if not blocks:
-                    raise e
-
-                blocks_without_urls = [
-                    block for block in blocks if not _check_for_url_in_block(block)
-                ]
-                blocks_without_urls.append(_build_error_block(str(e)))
-
-                # Try again without blocks containing url
-                response = slack_call(
-                    channel=channel,
-                    user=receiver,
-                    text=text,
-                    blocks=blocks_without_urls,
-                    thread_ts=thread_ts,
-                    metadata=metadata,
-                    unfurl_links=unfurl,
-                    unfurl_media=unfurl,
-                )

         if not response.get("ok"):
             raise RuntimeError(f"Failed to post message: {response}")
         message_ids.append(response["message_ts"])

     return message_ids
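The deleted branches above implement a post-then-degrade fallback: if Slack rejects the message (most often because a block contains a malformed URL), retry once with every URL-bearing block stripped and an error notice appended. A minimal sketch of the idea with plain dict blocks (the real helper worked on `Block.to_dict()`):

```python
def contains_url(block: dict) -> bool:
    # Recursive key scan, as the removed _check_for_url_in_block helper did.
    for key, value in block.items():
        if "url" in key.lower():
            return True
        if isinstance(value, dict) and contains_url(value):
            return True
        if isinstance(value, list) and any(
            isinstance(item, dict) and contains_url(item) for item in value
        ):
            return True
    return False

def post_with_fallback(post, blocks: list[dict]) -> None:
    # `post` stands in for the rate-limited Slack call in the diff above.
    try:
        post(blocks=blocks)
    except Exception:
        safe = [b for b in blocks if not contains_url(b)]
        safe.append({"type": "section", "text": "Some answers could not be displayed."})
        post(blocks=safe)  # one retry with the offending blocks removed
```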
@@ -20,6 +20,7 @@ from danswer.db.models import DocumentByConnectorCredentialPair
 from danswer.db.models import User
 from danswer.db.models import User__UserGroup
 from danswer.server.documents.models import CredentialBase
+from danswer.server.documents.models import CredentialDataUpdateRequest
 from danswer.utils.logger import setup_logger


@@ -261,8 +262,7 @@ def _cleanup_credential__user_group_relationships__no_commit(

 def alter_credential(
     credential_id: int,
-    name: str,
-    credential_json: dict[str, Any],
+    credential_data: CredentialDataUpdateRequest,
     user: User,
     db_session: Session,
 ) -> Credential | None:
@@ -272,13 +272,11 @@ def alter_credential(
     if credential is None:
         return None

-    credential.name = name
+    credential.name = credential_data.name

-    # Assign a new dictionary to credential.credential_json
-    credential.credential_json = {
-        **credential.credential_json,
-        **credential_json,
-    }
+    # Update only the keys present in credential_data.credential_json
+    for key, value in credential_data.credential_json.items():
+        credential.credential_json[key] = value

     credential.user_id = user.id if user is not None else None
     db_session.commit()
@@ -311,8 +309,8 @@ def update_credential_json(
     credential = fetch_credential_by_id(credential_id, user, db_session)
     if credential is None:
         return None

     credential.credential_json = credential_json

     db_session.commit()
     return credential
@@ -522,16 +522,12 @@ def expire_index_attempts(
     search_settings_id: int,
     db_session: Session,
 ) -> None:
-    not_started_query = (
-        update(IndexAttempt)
+    delete_query = (
+        delete(IndexAttempt)
         .where(IndexAttempt.search_settings_id == search_settings_id)
         .where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
-        .values(
-            status=IndexingStatus.CANCELED,
-            error_msg="Canceled, likely due to model swap",
-        )
     )
-    db_session.execute(not_started_query)
+    db_session.execute(delete_query)

     update_query = (
         update(IndexAttempt)
@@ -553,14 +549,9 @@ def cancel_indexing_attempts_for_ccpair(
     include_secondary_index: bool = False,
 ) -> None:
     stmt = (
-        update(IndexAttempt)
+        delete(IndexAttempt)
         .where(IndexAttempt.connector_credential_pair_id == cc_pair_id)
         .where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
-        .values(
-            status=IndexingStatus.CANCELED,
-            error_msg="Canceled by user",
-            time_started=datetime.now(timezone.utc),
-        )
     )

     if not include_secondary_index:
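Both hunks above swap an UPDATE that marked NOT_STARTED attempts as CANCELED for a DELETE that removes the rows outright. A short sketch of the two SQLAlchemy statement forms, reusing the models referenced above:

```python
from sqlalchemy import delete, update

# Before: keep the rows, flipping their status (an audit trail survives).
mark_canceled = (
    update(IndexAttempt)
    .where(IndexAttempt.status == IndexingStatus.NOT_STARTED)
    .values(status=IndexingStatus.CANCELED, error_msg="Canceled by user")
)

# After: drop the rows entirely (no record of the canceled attempt remains).
drop_not_started = delete(IndexAttempt).where(
    IndexAttempt.status == IndexingStatus.NOT_STARTED
)

# Either statement is executed the same way:
# db_session.execute(drop_not_started)
# db_session.commit()
```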
backend/danswer/db/input_prompt.py (new file, 202 lines)
@@ -0,0 +1,202 @@
from uuid import UUID

from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import InputPrompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.manage.models import UserInfo
from danswer.utils.logger import setup_logger


logger = setup_logger()


def insert_input_prompt_if_not_exists(
    user: User | None,
    input_prompt_id: int | None,
    prompt: str,
    content: str,
    active: bool,
    is_public: bool,
    db_session: Session,
    commit: bool = True,
) -> InputPrompt:
    if input_prompt_id is not None:
        input_prompt = (
            db_session.query(InputPrompt).filter_by(id=input_prompt_id).first()
        )
    else:
        query = db_session.query(InputPrompt).filter(InputPrompt.prompt == prompt)
        if user:
            query = query.filter(InputPrompt.user_id == user.id)
        else:
            query = query.filter(InputPrompt.user_id.is_(None))
        input_prompt = query.first()

    if input_prompt is None:
        input_prompt = InputPrompt(
            id=input_prompt_id,
            prompt=prompt,
            content=content,
            active=active,
            is_public=is_public or user is None,
            user_id=user.id if user else None,
        )
        db_session.add(input_prompt)

    if commit:
        db_session.commit()

    return input_prompt


def insert_input_prompt(
    prompt: str,
    content: str,
    is_public: bool,
    user: User | None,
    db_session: Session,
) -> InputPrompt:
    input_prompt = InputPrompt(
        prompt=prompt,
        content=content,
        active=True,
        is_public=is_public or user is None,
        user_id=user.id if user is not None else None,
    )
    db_session.add(input_prompt)
    db_session.commit()

    return input_prompt


def update_input_prompt(
    user: User | None,
    input_prompt_id: int,
    prompt: str,
    content: str,
    active: bool,
    db_session: Session,
) -> InputPrompt:
    input_prompt = db_session.scalar(
        select(InputPrompt).where(InputPrompt.id == input_prompt_id)
    )
    if input_prompt is None:
        raise ValueError(f"No input prompt with id {input_prompt_id}")

    if not validate_user_prompt_authorization(user, input_prompt):
        raise HTTPException(status_code=401, detail="You don't own this prompt")

    input_prompt.prompt = prompt
    input_prompt.content = content
    input_prompt.active = active

    db_session.commit()
    return input_prompt


def validate_user_prompt_authorization(
    user: User | None, input_prompt: InputPrompt
) -> bool:
    prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt)

    if prompt.user_id is not None:
        if user is None:
            return False

        user_details = UserInfo.from_model(user)
        if str(user_details.id) != str(prompt.user_id):
            return False
    return True


def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None:
    input_prompt = db_session.scalar(
        select(InputPrompt).where(InputPrompt.id == input_prompt_id)
    )

    if input_prompt is None:
        raise ValueError(f"No input prompt with id {input_prompt_id}")

    if not input_prompt.is_public:
        raise HTTPException(status_code=400, detail="This prompt is not public")

    db_session.delete(input_prompt)
    db_session.commit()


def remove_input_prompt(
    user: User | None, input_prompt_id: int, db_session: Session
) -> None:
    input_prompt = db_session.scalar(
        select(InputPrompt).where(InputPrompt.id == input_prompt_id)
    )
    if input_prompt is None:
        raise ValueError(f"No input prompt with id {input_prompt_id}")

    if input_prompt.is_public:
        raise HTTPException(
            status_code=400, detail="Cannot delete public prompts with this method"
        )

    if not validate_user_prompt_authorization(user, input_prompt):
        raise HTTPException(status_code=401, detail="You do not own this prompt")

    db_session.delete(input_prompt)
    db_session.commit()


def fetch_input_prompt_by_id(
    id: int, user_id: UUID | None, db_session: Session
) -> InputPrompt:
    query = select(InputPrompt).where(InputPrompt.id == id)

    if user_id:
        query = query.where(
            (InputPrompt.user_id == user_id) | (InputPrompt.user_id.is_(None))
        )
    else:
        # If no user_id is provided, only fetch prompts without a user_id (aka public)
        query = query.where(InputPrompt.user_id == None)  # noqa

    result = db_session.scalar(query)

    if result is None:
        raise HTTPException(422, "No input prompt found")

    return result


def fetch_public_input_prompts(
    db_session: Session,
) -> list[InputPrompt]:
    query = select(InputPrompt).where(InputPrompt.is_public)
    return list(db_session.scalars(query).all())


def fetch_input_prompts_by_user(
    db_session: Session,
    user_id: UUID | None,
    active: bool | None = None,
    include_public: bool = False,
) -> list[InputPrompt]:
    query = select(InputPrompt)

    if user_id is not None:
        if include_public:
            query = query.where(
                (InputPrompt.user_id == user_id) | InputPrompt.is_public
            )
        else:
            query = query.where(InputPrompt.user_id == user_id)

    elif include_public:
        query = query.where(InputPrompt.is_public)

    if active is not None:
        query = query.where(InputPrompt.active == active)

    return list(db_session.scalars(query).all())
@@ -159,6 +159,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
     )

     prompts: Mapped[list["Prompt"]] = relationship("Prompt", back_populates="user")
+    input_prompts: Mapped[list["InputPrompt"]] = relationship(
+        "InputPrompt", back_populates="user"
+    )

     # Personas owned by this user
     personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user")
@@ -175,6 +178,31 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
     )


+class InputPrompt(Base):
+    __tablename__ = "inputprompt"
+
+    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
+    prompt: Mapped[str] = mapped_column(String)
+    content: Mapped[str] = mapped_column(String)
+    active: Mapped[bool] = mapped_column(Boolean)
+    user: Mapped[User | None] = relationship("User", back_populates="input_prompts")
+    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
+    user_id: Mapped[UUID | None] = mapped_column(
+        ForeignKey("user.id", ondelete="CASCADE"), nullable=True
+    )
+
+
+class InputPrompt__User(Base):
+    __tablename__ = "inputprompt__user"
+
+    input_prompt_id: Mapped[int] = mapped_column(
+        ForeignKey("inputprompt.id"), primary_key=True
+    )
+    user_id: Mapped[UUID | None] = mapped_column(
+        ForeignKey("inputprompt.id"), primary_key=True
+    )
+
+
 class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
     pass
@@ -568,25 +596,6 @@ class Connector(Base):
         list["DocumentByConnectorCredentialPair"]
     ] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")

-    # synchronize this validation logic with RefreshFrequencySchema etc on front end
-    # until we have a centralized validation schema
-
-    # TODO(rkuo): experiment with SQLAlchemy validators rather than manual checks
-    # https://docs.sqlalchemy.org/en/20/orm/mapped_attributes.html
-    def validate_refresh_freq(self) -> None:
-        if self.refresh_freq is not None:
-            if self.refresh_freq < 60:
-                raise ValueError(
-                    "refresh_freq must be greater than or equal to 60 seconds."
-                )
-
-    def validate_prune_freq(self) -> None:
-        if self.prune_freq is not None:
-            if self.prune_freq < 86400:
-                raise ValueError(
-                    "prune_freq must be greater than or equal to 86400 seconds."
-                )
-

 class Credential(Base):
     __tablename__ = "credential"
@@ -70,7 +70,7 @@ def get_file_ext(file_path_or_name: str | Path) -> str:
     return extension


-def is_valid_file_ext(ext: str) -> bool:
+def check_file_ext_is_valid(ext: str) -> bool:
     return ext in VALID_FILE_EXTENSIONS


@@ -364,7 +364,7 @@ def extract_file_text(
     elif file_name is not None:
         final_extension = get_file_ext(file_name)

-        if is_valid_file_ext(final_extension):
+        if check_file_ext_is_valid(final_extension):
             return extension_to_function.get(final_extension, file_io_to_text)(file)

     # Either the file somehow has no name or the extension is not one that we recognize
@@ -1,5 +1,4 @@
 import traceback
-from collections.abc import Callable
 from functools import partial
 from http import HTTPStatus
 from typing import Protocol
@@ -13,7 +12,6 @@ from danswer.access.access import get_access_for_documents
 from danswer.access.models import DocumentAccess
 from danswer.configs.app_configs import ENABLE_MULTIPASS_INDEXING
 from danswer.configs.app_configs import INDEXING_EXCEPTION_LIMIT
-from danswer.configs.app_configs import MAX_DOCUMENT_CHARS
 from danswer.configs.constants import DEFAULT_BOOST
 from danswer.connectors.cross_connector_utils.miscellaneous_utils import (
     get_experts_stores_representations,
@@ -204,13 +202,40 @@ def index_doc_batch_with_handler(


 def index_doc_batch_prepare(
-    documents: list[Document],
+    document_batch: list[Document],
     index_attempt_metadata: IndexAttemptMetadata,
     db_session: Session,
     ignore_time_skip: bool = False,
 ) -> DocumentBatchPrepareContext | None:
     """Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
     This precedes indexing it into the actual document index."""
+    documents: list[Document] = []
+    for document in document_batch:
+        empty_contents = not any(section.text.strip() for section in document.sections)
+        if (
+            (not document.title or not document.title.strip())
+            and not document.semantic_identifier.strip()
+            and empty_contents
+        ):
+            # Skip documents that have neither title nor content
+            # If the document doesn't have either, then there is no useful information in it
+            # This is again verified later in the pipeline after chunking but at that point there should
+            # already be no documents that are empty.
+            logger.warning(
+                f"Skipping document with ID {document.id} as it has neither title nor content."
+            )
+            continue
+
+        if document.title is not None and not document.title.strip() and empty_contents:
+            # The title is explicitly empty ("" and not None) and the document is empty
+            # so when building the chunk text representation, it will be empty and unusable
+            logger.warning(
+                f"Skipping document with ID {document.id} as the chunks will be empty."
+            )
+            continue
+
+        documents.append(document)
+
     # Create a trimmed list of docs that don't have a newer updated at
     # Shortcuts the time-consuming flow on connector index retries
     document_ids: list[str] = [document.id for document in documents]
@@ -257,64 +282,17 @@ def index_doc_batch_prepare(
     )


-def filter_documents(document_batch: list[Document]) -> list[Document]:
-    documents: list[Document] = []
-    for document in document_batch:
-        empty_contents = not any(section.text.strip() for section in document.sections)
-        if (
-            (not document.title or not document.title.strip())
-            and not document.semantic_identifier.strip()
-            and empty_contents
-        ):
-            # Skip documents that have neither title nor content
-            # If the document doesn't have either, then there is no useful information in it
-            # This is again verified later in the pipeline after chunking but at that point there should
-            # already be no documents that are empty.
-            logger.warning(
-                f"Skipping document with ID {document.id} as it has neither title nor content."
-            )
-            continue
-
-        if document.title is not None and not document.title.strip() and empty_contents:
-            # The title is explicitly empty ("" and not None) and the document is empty
-            # so when building the chunk text representation, it will be empty and unusable
-            logger.warning(
-                f"Skipping document with ID {document.id} as the chunks will be empty."
-            )
-            continue
-
-        section_chars = sum(len(section.text) for section in document.sections)
-        if (
-            MAX_DOCUMENT_CHARS
-            and len(document.title or document.semantic_identifier) + section_chars
-            > MAX_DOCUMENT_CHARS
-        ):
-            # Skip documents that are too long, later on there are more memory intensive steps done on the text
-            # and the container will run out of memory and crash. Several other checks are included upstream but
-            # those are at the connector level so a catchall is still needed.
-            # Assumption here is that files that are that long, are generated files and not the type users
-            # generally care for.
-            logger.warning(
-                f"Skipping document with ID {document.id} as it is too long."
-            )
-            continue
-
-        documents.append(document)
-    return documents
-
-
 @log_function_time(debug_only=True)
 def index_doc_batch(
     *,
-    document_batch: list[Document],
     chunker: Chunker,
     embedder: IndexingEmbedder,
     document_index: DocumentIndex,
+    document_batch: list[Document],
     index_attempt_metadata: IndexAttemptMetadata,
     db_session: Session,
     ignore_time_skip: bool = False,
     tenant_id: str | None = None,
-    filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
 ) -> tuple[int, int]:
     """Takes different pieces of the indexing pipeline and applies it to a batch of documents
     Note that the documents should already be batched at this point so that it does not inflate the
@@ -331,11 +309,8 @@ def index_doc_batch(
         is_public=False,
     )

-    logger.debug("Filtering Documents")
-    filtered_documents = filter_fnc(document_batch)
-
     ctx = index_doc_batch_prepare(
-        documents=filtered_documents,
+        document_batch=document_batch,
         index_attempt_metadata=index_attempt_metadata,
         ignore_time_skip=ignore_time_skip,
         db_session=db_session,
@@ -1,4 +1,5 @@
import copy
import io
import json
from collections.abc import Callable
from collections.abc import Iterator
@@ -6,6 +7,7 @@ from typing import Any
from typing import cast

import litellm  # type: ignore
import pandas as pd
import tiktoken
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
@@ -98,32 +100,53 @@ def litellm_exception_to_error_msg(
    return error_msg


# Processes CSV files to show the first 5 rows and max_columns (default 40) columns
def _process_csv_file(file: InMemoryChatFile, max_columns: int = 40) -> str:
    df = pd.read_csv(io.StringIO(file.content.decode("utf-8")))

    csv_preview = df.head().to_string(max_cols=max_columns)

    file_name_section = (
        f"CSV FILE NAME: {file.filename}\n"
        if file.filename
        else "CSV FILE (NO NAME PROVIDED):\n"
    )
    return f"{file_name_section}{CODE_BLOCK_PAT.format(csv_preview)}\n\n\n"
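For a feel of the preview string built above, here is a minimal, self-contained sketch; the CSV bytes are invented, and the CODE_BLOCK_PAT wrapping is left out:

import io

import pandas as pd

csv_bytes = b"name,team,score\nana,red,3\nbo,blue,5\ncy,red,9\n"  # made-up file content
df = pd.read_csv(io.StringIO(csv_bytes.decode("utf-8")))
# head() limits output to the first 5 rows; max_cols truncates wide frames with "..."
print(df.head().to_string(max_cols=2))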
def _build_content(
    message: str,
    files: list[InMemoryChatFile] | None = None,
) -> str:
    """Applies all non-image files."""
    if not files:
        return message
    text_files = (
        [file for file in files if file.file_type == ChatFileType.PLAIN_TEXT]
        if files
        else None
    )

    text_files = [
        file
        for file in files
        if file.file_type in (ChatFileType.PLAIN_TEXT, ChatFileType.CSV)
    ]
    csv_files = (
        [file for file in files if file.file_type == ChatFileType.CSV]
        if files
        else None
    )

    if not text_files:
    if not text_files and not csv_files:
        return message

    final_message_with_files = "FILES:\n\n"
    for file in text_files:
    for file in text_files or []:
        file_content = file.content.decode("utf-8")
        file_name_section = f"DOCUMENT: {file.filename}\n" if file.filename else ""
        final_message_with_files += (
            f"{file_name_section}{CODE_BLOCK_PAT.format(file_content.strip())}\n\n\n"
        )
    for file in csv_files or []:
        final_message_with_files += _process_csv_file(file)

    return final_message_with_files + message
    final_message_with_files += message

    return final_message_with_files


def build_content_with_imgs(

@@ -52,9 +52,12 @@ from danswer.server.documents.connector import router as connector_router
from danswer.server.documents.credential import router as credential_router
from danswer.server.documents.document import router as document_router
from danswer.server.documents.indexing import router as indexing_router
from danswer.server.documents.standard_oauth import router as oauth_router
from danswer.server.features.document_set.api import router as document_set_router
from danswer.server.features.folder.api import router as folder_router
from danswer.server.features.input_prompt.api import (
    admin_router as admin_input_prompt_router,
)
from danswer.server.features.input_prompt.api import basic_router as input_prompt_router
from danswer.server.features.notifications.api import router as notification_router
from danswer.server.features.persona.api import admin_router as admin_persona_router
from danswer.server.features.persona.api import basic_router as persona_router
@@ -255,6 +258,8 @@ def get_application() -> FastAPI:
    )
    include_router_with_global_prefix_prepended(application, persona_router)
    include_router_with_global_prefix_prepended(application, admin_persona_router)
    include_router_with_global_prefix_prepended(application, input_prompt_router)
    include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
    include_router_with_global_prefix_prepended(application, notification_router)
    include_router_with_global_prefix_prepended(application, prompt_router)
    include_router_with_global_prefix_prepended(application, tool_router)
@@ -277,7 +282,6 @@ def get_application() -> FastAPI:
    )
    include_router_with_global_prefix_prepended(application, long_term_logs_router)
    include_router_with_global_prefix_prepended(application, api_key_router)
    include_router_with_global_prefix_prepended(application, oauth_router)

    if AUTH_TYPE == AuthType.DISABLED:
        # Server logs this during auth setup verification step

backend/danswer/seeding/input_prompts.yaml (new file, 24 lines)
@@ -0,0 +1,24 @@
input_prompts:
  - id: -5
    prompt: "Elaborate"
    content: "Elaborate on the above, give me a more in depth explanation."
    active: true
    is_public: true

  - id: -4
    prompt: "Reword"
    content: "Help me rewrite the following politely and concisely for professional communication:\n"
    active: true
    is_public: true

  - id: -3
    prompt: "Email"
    content: "Write a professional email for me including a subject line, signature, etc. Template the parts that need editing with [ ]. The email should cover the following points:\n"
    active: true
    is_public: true

  - id: -2
    prompt: "Debug"
    content: "Provide step-by-step troubleshooting instructions for the following issue:\n"
    active: true
    is_public: true
@@ -196,7 +196,7 @@ def seed_initial_documents(
    docs, chunks = _create_indexable_chunks(processed_docs, tenant_id)

    index_doc_batch_prepare(
        documents=docs,
        document_batch=docs,
        index_attempt_metadata=IndexAttemptMetadata(
            connector_id=connector_id,
            credential_id=PUBLIC_CREDENTIAL_ID,

@@ -1,11 +1,13 @@
import yaml
from sqlalchemy.orm import Session

from danswer.configs.chat_configs import INPUT_PROMPT_YAML
from danswer.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from danswer.configs.chat_configs import PERSONAS_YAML
from danswer.configs.chat_configs import PROMPTS_YAML
from danswer.context.search.enums import RecencyBiasSetting
from danswer.db.document_set import get_or_create_document_set_by_name
from danswer.db.input_prompt import insert_input_prompt_if_not_exists
from danswer.db.models import DocumentSet as DocumentSetDBModel
from danswer.db.models import Persona
from danswer.db.models import Prompt as PromptDBModel
@@ -138,10 +140,35 @@ def load_personas_from_yaml(
    )


def load_input_prompts_from_yaml(
    db_session: Session, input_prompts_yaml: str = INPUT_PROMPT_YAML
) -> None:
    with open(input_prompts_yaml, "r") as file:
        data = yaml.safe_load(file)

    all_input_prompts = data.get("input_prompts", [])
    for input_prompt in all_input_prompts:
        # If these prompts are deleted (which is a hard delete in the DB), on server startup
        # they will be recreated, but the user can always just deactivate them, just a light inconvenience

        insert_input_prompt_if_not_exists(
            user=None,
            input_prompt_id=input_prompt.get("id"),
            prompt=input_prompt["prompt"],
            content=input_prompt["content"],
            is_public=input_prompt["is_public"],
            active=input_prompt.get("active", True),
            db_session=db_session,
            commit=True,
        )


def load_chat_yamls(
    db_session: Session,
    prompt_yaml: str = PROMPTS_YAML,
    personas_yaml: str = PERSONAS_YAML,
    input_prompts_yaml: str = INPUT_PROMPT_YAML,
) -> None:
    load_prompts_from_yaml(db_session, prompt_yaml)
    load_personas_from_yaml(db_session, personas_yaml)
    load_input_prompts_from_yaml(db_session, input_prompts_yaml)

@@ -33,6 +33,8 @@ from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.enums import AccessType
from danswer.db.enums import ConnectorCredentialPairStatus
from danswer.db.index_attempt import cancel_indexing_attempts_for_ccpair
from danswer.db.index_attempt import cancel_indexing_attempts_past_model
from danswer.db.index_attempt import count_index_attempts_for_connector
from danswer.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from danswer.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id
@@ -43,7 +45,6 @@ from danswer.db.search_settings import get_current_search_settings
from danswer.redis.redis_connector import RedisConnector
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CCPairFullInfo
from danswer.server.documents.models import CCPropertyUpdateRequest
from danswer.server.documents.models import CCStatusUpdateRequest
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.documents.models import ConnectorCredentialPairMetadata
@@ -191,6 +192,9 @@ def update_cc_pair_status(
        db_session
    )

    cancel_indexing_attempts_for_ccpair(cc_pair_id, db_session)
    cancel_indexing_attempts_past_model(db_session)

    redis_connector = RedisConnector(tenant_id, cc_pair_id)

    try:
@@ -304,46 +308,6 @@ def update_cc_pair_name(
        raise HTTPException(status_code=400, detail="Name must be unique")


@router.put("/admin/cc-pair/{cc_pair_id}/property")
def update_cc_pair_property(
    cc_pair_id: int,
    update_request: CCPropertyUpdateRequest,  # in seconds
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse[int]:
    cc_pair = get_connector_credential_pair_from_id(
        cc_pair_id=cc_pair_id,
        db_session=db_session,
        user=user,
        get_editable=True,
    )
    if not cc_pair:
        raise HTTPException(
            status_code=400, detail="CC Pair not found for current user's permissions"
        )

    # Can we centralize logic for updating connector properties
    # so that we don't need to manually validate everywhere?
    if update_request.name == "refresh_frequency":
        cc_pair.connector.refresh_freq = int(update_request.value)
        cc_pair.connector.validate_refresh_freq()
        db_session.commit()

        msg = "Refresh frequency updated successfully"
    elif update_request.name == "pruning_frequency":
        cc_pair.connector.prune_freq = int(update_request.value)
        cc_pair.connector.validate_prune_freq()
        db_session.commit()

        msg = "Pruning frequency updated successfully"
    else:
        raise HTTPException(
            status_code=400, detail=f"Property name {update_request.name} is not valid."
        )

    return StatusResponse(success=True, message=msg, data=cc_pair_id)


@router.get("/admin/cc-pair/{cc_pair_id}/last_pruned")
def get_cc_pair_last_pruned(
    cc_pair_id: int,

@@ -181,13 +181,7 @@ def update_credential_data(
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CredentialBase:
    credential = alter_credential(
        credential_id,
        credential_update.name,
        credential_update.credential_json,
        user,
        db_session,
    )
    credential = alter_credential(credential_id, credential_update, user, db_session)

    if credential is None:
        raise HTTPException(

@@ -364,11 +364,6 @@ class RunConnectorRequest(BaseModel):
    from_beginning: bool = False


class CCPropertyUpdateRequest(BaseModel):
    name: str
    value: str


"""Connectors Models"""


@@ -1,142 +0,0 @@
import uuid
from typing import Annotated
from typing import cast

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from pydantic import BaseModel
from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.constants import DocumentSource
from danswer.connectors.interfaces import OAuthConnector
from danswer.db.credentials import create_credential
from danswer.db.engine import get_current_tenant_id
from danswer.db.engine import get_session
from danswer.db.models import User
from danswer.redis.redis_pool import get_redis_client
from danswer.server.documents.models import CredentialBase
from danswer.utils.logger import setup_logger
from danswer.utils.subclasses import find_all_subclasses_in_dir

logger = setup_logger()

router = APIRouter(prefix="/connector/oauth")

_OAUTH_STATE_KEY_FMT = "oauth_state:{state}"
_OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60  # 10 minutes

# Cache for OAuth connectors, populated at module load time
_OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {}


def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]:
    """Walk through the connectors package to find all OAuthConnector implementations"""
    global _OAUTH_CONNECTORS
    if _OAUTH_CONNECTORS:  # Return cached connectors if already discovered
        return _OAUTH_CONNECTORS

    oauth_connectors = find_all_subclasses_in_dir(
        cast(type[OAuthConnector], OAuthConnector), "danswer.connectors"
    )

    _OAUTH_CONNECTORS = {cls.oauth_id(): cls for cls in oauth_connectors}
    return _OAUTH_CONNECTORS
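Since discovery is purely subclass-based, registering a new provider only requires an OAuthConnector subclass somewhere under danswer.connectors. A hypothetical sketch follows; the classmethod style and exact signatures are inferred from the call sites in this file, not confirmed against the interface definition:

class AcmeOAuthConnector(OAuthConnector):  # hypothetical example connector
    @classmethod
    def oauth_id(cls) -> DocumentSource:
        return DocumentSource.WEB  # a real connector would return its own source

    @classmethod
    def oauth_authorization_url(cls, base_domain: str, state: str) -> str:
        # matches the connector_cls.oauth_authorization_url(base_url, state) call below
        return f"https://auth.acme.example/authorize?state={state}&redirect_uri={base_domain}"

    @classmethod
    def oauth_code_to_token(cls, code: str) -> dict:
        # matches the connector_cls.oauth_code_to_token(code) call in the callback
        return {"access_token": f"token-for-{code}"}  # placeholder exchange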


# Discover OAuth connectors at module load time
_discover_oauth_connectors()


class AuthorizeResponse(BaseModel):
    redirect_url: str


@router.get("/authorize/{source}")
def oauth_authorize(
    source: DocumentSource,
    desired_return_url: Annotated[str | None, Query()] = None,
    _: User = Depends(current_user),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> AuthorizeResponse:
    """Initiates the OAuth flow by redirecting to the provider's auth page"""
    oauth_connectors = _discover_oauth_connectors()

    if source not in oauth_connectors:
        raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")

    connector_cls = oauth_connectors[source]
    base_url = WEB_DOMAIN

    # store state in redis
    if not desired_return_url:
        desired_return_url = f"{base_url}/admin/connectors/{source}?step=0"
    redis_client = get_redis_client(tenant_id=tenant_id)
    state = str(uuid.uuid4())
    redis_client.set(
        _OAUTH_STATE_KEY_FMT.format(state=state),
        desired_return_url,
        ex=_OAUTH_STATE_EXPIRATION_SECONDS,
    )

    return AuthorizeResponse(
        redirect_url=connector_cls.oauth_authorization_url(base_url, state)
    )


class CallbackResponse(BaseModel):
    redirect_url: str


@router.get("/callback/{source}")
def oauth_callback(
    source: DocumentSource,
    code: Annotated[str, Query()],
    state: Annotated[str, Query()],
    db_session: Session = Depends(get_session),
    user: User = Depends(current_user),
    tenant_id: str | None = Depends(get_current_tenant_id),
) -> CallbackResponse:
    """Handles the OAuth callback and exchanges the code for tokens"""
    oauth_connectors = _discover_oauth_connectors()

    if source not in oauth_connectors:
        raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}")

    connector_cls = oauth_connectors[source]

    # get state from redis
    redis_client = get_redis_client(tenant_id=tenant_id)
    original_url_bytes = cast(
        bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state))
    )
    if not original_url_bytes:
        raise HTTPException(status_code=400, detail="Invalid OAuth state")
    original_url = original_url_bytes.decode("utf-8")

    token_info = connector_cls.oauth_code_to_token(code)

    # Create a new credential with the token info
    credential_data = CredentialBase(
        credential_json=token_info,
        admin_public=True,  # Or based on some logic/parameter
        source=source,
        name=f"{source.title()} OAuth Credential",
    )

    credential = create_credential(
        credential_data=credential_data,
        user=user,
        db_session=db_session,
    )

    return CallbackResponse(
        redirect_url=(
            f"{original_url}?credentialId={credential.id}"
            if "?" not in original_url
            else f"{original_url}&credentialId={credential.id}"
        )
    )
backend/danswer/server/features/input_prompt/api.py (new file, 134 lines)
@@ -0,0 +1,134 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.db.engine import get_session
from danswer.db.input_prompt import fetch_input_prompt_by_id
from danswer.db.input_prompt import fetch_input_prompts_by_user
from danswer.db.input_prompt import fetch_public_input_prompts
from danswer.db.input_prompt import insert_input_prompt
from danswer.db.input_prompt import remove_input_prompt
from danswer.db.input_prompt import remove_public_input_prompt
from danswer.db.input_prompt import update_input_prompt
from danswer.db.models import User
from danswer.server.features.input_prompt.models import CreateInputPromptRequest
from danswer.server.features.input_prompt.models import InputPromptSnapshot
from danswer.server.features.input_prompt.models import UpdateInputPromptRequest
from danswer.utils.logger import setup_logger

logger = setup_logger()

basic_router = APIRouter(prefix="/input_prompt")
admin_router = APIRouter(prefix="/admin/input_prompt")


@basic_router.get("")
def list_input_prompts(
    user: User | None = Depends(current_user),
    include_public: bool = False,
    db_session: Session = Depends(get_session),
) -> list[InputPromptSnapshot]:
    user_prompts = fetch_input_prompts_by_user(
        user_id=user.id if user is not None else None,
        db_session=db_session,
        include_public=include_public,
    )
    return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]


@basic_router.get("/{input_prompt_id}")
def get_input_prompt(
    input_prompt_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
    input_prompt = fetch_input_prompt_by_id(
        id=input_prompt_id,
        user_id=user.id if user is not None else None,
        db_session=db_session,
    )
    return InputPromptSnapshot.from_model(input_prompt=input_prompt)


@basic_router.post("")
def create_input_prompt(
    create_input_prompt_request: CreateInputPromptRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
    input_prompt = insert_input_prompt(
        prompt=create_input_prompt_request.prompt,
        content=create_input_prompt_request.content,
        is_public=create_input_prompt_request.is_public,
        user=user,
        db_session=db_session,
    )
    return InputPromptSnapshot.from_model(input_prompt)


@basic_router.patch("/{input_prompt_id}")
def patch_input_prompt(
    input_prompt_id: int,
    update_input_prompt_request: UpdateInputPromptRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> InputPromptSnapshot:
    try:
        updated_input_prompt = update_input_prompt(
            user=user,
            input_prompt_id=input_prompt_id,
            prompt=update_input_prompt_request.prompt,
            content=update_input_prompt_request.content,
            active=update_input_prompt_request.active,
            db_session=db_session,
        )
    except ValueError as e:
        error_msg = "Error occurred while updating input prompt"
        logger.warn(f"{error_msg}. Stack trace: {e}")
        raise HTTPException(status_code=404, detail=error_msg)

    return InputPromptSnapshot.from_model(updated_input_prompt)


@basic_router.delete("/{input_prompt_id}")
def delete_input_prompt(
    input_prompt_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    try:
        remove_input_prompt(user, input_prompt_id, db_session)

    except ValueError as e:
        error_msg = "Error occurred while deleting input prompt"
        logger.warn(f"{error_msg}. Stack trace: {e}")
        raise HTTPException(status_code=404, detail=error_msg)


@admin_router.delete("/{input_prompt_id}")
def delete_public_input_prompt(
    input_prompt_id: int,
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    try:
        remove_public_input_prompt(input_prompt_id, db_session)

    except ValueError as e:
        error_msg = "Error occurred while deleting input prompt"
        logger.warn(f"{error_msg}. Stack trace: {e}")
        raise HTTPException(status_code=404, detail=error_msg)


@admin_router.get("")
def list_public_input_prompts(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[InputPromptSnapshot]:
    user_prompts = fetch_public_input_prompts(
        db_session=db_session,
    )
    return [InputPromptSnapshot.from_model(prompt) for prompt in user_prompts]
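A quick sketch of hitting these endpoints, in the same requests style the integration test further below uses (the host and auth header are placeholders; note that the test expects 403 for a limited-permission API key):

import requests

BASE = "http://localhost:8080"  # assumed server address
headers = {"Authorization": "Bearer <api-key>"}  # placeholder credentials

created = requests.post(
    f"{BASE}/input_prompt",
    json={"prompt": "Summarize", "content": "Summarize the above.", "is_public": False},
    headers=headers,
).json()
# PATCH edits or deactivates, DELETE removes; list with include_public to also see seeded prompts
prompts = requests.get(f"{BASE}/input_prompt?include_public=true", headers=headers).json()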
backend/danswer/server/features/input_prompt/models.py (new file, 47 lines)
@@ -0,0 +1,47 @@
from uuid import UUID

from pydantic import BaseModel

from danswer.db.models import InputPrompt
from danswer.utils.logger import setup_logger

logger = setup_logger()


class CreateInputPromptRequest(BaseModel):
    prompt: str
    content: str
    is_public: bool


class UpdateInputPromptRequest(BaseModel):
    prompt: str
    content: str
    active: bool


class InputPromptResponse(BaseModel):
    id: int
    prompt: str
    content: str
    active: bool


class InputPromptSnapshot(BaseModel):
    id: int
    prompt: str
    content: str
    active: bool
    user_id: UUID | None
    is_public: bool

    @classmethod
    def from_model(cls, input_prompt: InputPrompt) -> "InputPromptSnapshot":
        return InputPromptSnapshot(
            id=input_prompt.id,
            prompt=input_prompt.prompt,
            content=input_prompt.content,
            active=input_prompt.active,
            user_id=input_prompt.user_id,
            is_public=input_prompt.is_public,
        )
@@ -266,7 +266,5 @@ class FullModelVersionResponse(BaseModel):
class AllUsersResponse(BaseModel):
    accepted: list[FullUserSnapshot]
    invited: list[InvitedUserSnapshot]
    slack_users: list[FullUserSnapshot]
    accepted_pages: int
    invited_pages: int
    slack_users_pages: int

@@ -119,7 +119,6 @@ def set_user_role(
def list_all_users(
    q: str | None = None,
    accepted_page: int | None = None,
    slack_users_page: int | None = None,
    invited_page: int | None = None,
    user: User | None = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
@@ -132,12 +131,7 @@ def list_all_users(
        for user in list_users(db_session, email_filter_string=q)
        if not is_api_key_email_address(user.email)
    ]

    slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
    accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]

    accepted_emails = {user.email for user in accepted_users}
    slack_users_emails = {user.email for user in slack_users}
    accepted_emails = {user.email for user in users}
    invited_emails = get_invited_users()
    if q:
        invited_emails = [
@@ -145,11 +139,10 @@ def list_all_users(
        ]

    accepted_count = len(accepted_emails)
    slack_users_count = len(slack_users_emails)
    invited_count = len(invited_emails)

    # If any of q, accepted_page, or invited_page is None, return all users
    if accepted_page is None or invited_page is None or slack_users_page is None:
    if accepted_page is None or invited_page is None:
        return AllUsersResponse(
            accepted=[
                FullUserSnapshot(
@@ -160,23 +153,11 @@ def list_all_users(
                        UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
                    ),
                )
                for user in accepted_users
            ],
            slack_users=[
                FullUserSnapshot(
                    id=user.id,
                    email=user.email,
                    role=user.role,
                    status=(
                        UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED
                    ),
                )
                for user in slack_users
                for user in users
            ],
            invited=[InvitedUserSnapshot(email=email) for email in invited_emails],
            accepted_pages=1,
            invited_pages=1,
            slack_users_pages=1,
        )

    # Otherwise, return paginated results
@@ -188,27 +169,13 @@ def list_all_users(
                role=user.role,
                status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
            )
            for user in accepted_users
            for user in users
        ][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE],
        slack_users=[
            FullUserSnapshot(
                id=user.id,
                email=user.email,
                role=user.role,
                status=UserStatus.LIVE if user.is_active else UserStatus.DEACTIVATED,
            )
            for user in slack_users
        ][
            slack_users_page
            * USERS_PAGE_SIZE : (slack_users_page + 1)
            * USERS_PAGE_SIZE
        ],
        invited=[InvitedUserSnapshot(email=email) for email in invited_emails][
            invited_page * USERS_PAGE_SIZE : (invited_page + 1) * USERS_PAGE_SIZE
        ],
        accepted_pages=accepted_count // USERS_PAGE_SIZE + 1,
        invited_pages=invited_count // USERS_PAGE_SIZE + 1,
        slack_users_pages=slack_users_count // USERS_PAGE_SIZE + 1,
    )
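One detail worth noticing in the paginated branch: the `count // USERS_PAGE_SIZE + 1` page math rounds up by a full page whenever the count is an exact multiple of the page size. A tiny illustration, with the USERS_PAGE_SIZE value assumed:

USERS_PAGE_SIZE = 10  # assumed value for illustration
emails = [f"user{i}@example.com" for i in range(20)]
page = 1
print(emails[page * USERS_PAGE_SIZE : (page + 1) * USERS_PAGE_SIZE])  # items 10..19
print(len(emails) // USERS_PAGE_SIZE + 1)  # 3, though 2 pages already hold all 20 emails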


@@ -39,6 +39,7 @@ from danswer.key_value_store.interface import KvKeyNotFoundError
from danswer.natural_language_processing.search_nlp_models import EmbeddingModel
from danswer.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from danswer.natural_language_processing.search_nlp_models import warm_up_cross_encoder
from danswer.seeding.load_docs import seed_initial_documents
from danswer.seeding.load_yamls import load_chat_yamls
from danswer.server.manage.llm.models import LLMProviderUpsertRequest
from danswer.server.settings.store import load_settings
@@ -150,7 +151,7 @@ def setup_danswer(
    # update multipass indexing setting based on GPU availability
    update_default_multipass_indexing(db_session)

    # seed_initial_documents(db_session, tenant_id, cohere_enabled)
    seed_initial_documents(db_session, tenant_id, cohere_enabled)


def translate_saved_search_settings(db_session: Session) -> None:

@@ -48,9 +48,6 @@ from danswer.tools.tool_implementations.search_like_tool_utils import (
from danswer.tools.tool_implementations.search_like_tool_utils import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.tools.tool_implementations.search_like_tool_utils import (
    ORIGINAL_CONTEXT_DOCUMENTS_ID,
)
from danswer.utils.logger import setup_logger
from danswer.utils.special_types import JSON_ro

@@ -394,35 +391,15 @@ class SearchTool(Tool):
    """Other utility functions"""

    @classmethod
    def get_search_result(
        cls, llm_call: LLMCall
    ) -> tuple[list[LlmDoc], dict[str, int]] | None:
        """
        Returns the final search results and a map of docs to their original search rank (which is what is displayed to user)
        """
    def get_search_result(cls, llm_call: LLMCall) -> list[LlmDoc] | None:
        if not llm_call.tool_call_info:
            return None

        final_search_results = []
        doc_id_to_original_search_rank_map = {}

        for yield_item in llm_call.tool_call_info:
            if (
                isinstance(yield_item, ToolResponse)
                and yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID
            ):
                final_search_results = cast(list[LlmDoc], yield_item.response)
            elif (
                isinstance(yield_item, ToolResponse)
                and yield_item.id == ORIGINAL_CONTEXT_DOCUMENTS_ID
            ):
                search_contexts = yield_item.response.contexts
                original_doc_search_rank = 1
                for idx, doc in enumerate(search_contexts):
                    if doc.document_id not in doc_id_to_original_search_rank_map:
                        doc_id_to_original_search_rank_map[
                            doc.document_id
                        ] = original_doc_search_rank
                        original_doc_search_rank += 1
                return cast(list[LlmDoc], yield_item.response)

        return final_search_results, doc_id_to_original_search_rank_map
        return None

@@ -15,7 +15,6 @@ from danswer.tools.message import ToolCallSummary
from danswer.tools.models import ToolResponse


ORIGINAL_CONTEXT_DOCUMENTS_ID = "search_doc_content"
FINAL_CONTEXT_DOCUMENTS_ID = "final_context_documents"


@@ -1,77 +0,0 @@
from __future__ import annotations

import importlib
import os
import pkgutil
import sys
from types import ModuleType
from typing import List
from typing import Type
from typing import TypeVar

T = TypeVar("T")


def import_all_modules_from_dir(dir_path: str) -> List[ModuleType]:
    """
    Imports all modules found in the given directory and its subdirectories,
    returning a list of imported module objects.
    """
    dir_path = os.path.abspath(dir_path)

    if dir_path not in sys.path:
        sys.path.insert(0, dir_path)

    imported_modules: List[ModuleType] = []

    for _, package_name, _ in pkgutil.walk_packages([dir_path]):
        try:
            module = importlib.import_module(package_name)
            imported_modules.append(module)
        except Exception as e:
            # Handle or log exceptions as needed
            print(f"Could not import {package_name}: {e}")

    return imported_modules


def all_subclasses(cls: Type[T]) -> List[Type[T]]:
    """
    Recursively find all subclasses of the given class.
    """
    direct_subs = cls.__subclasses__()
    result: List[Type[T]] = []
    for subclass in direct_subs:
        result.append(subclass)
        # Extend the result by recursively calling all_subclasses
        result.extend(all_subclasses(subclass))
    return result


def find_all_subclasses_in_dir(parent_class: Type[T], directory: str) -> List[Type[T]]:
    """
    Imports all modules from the given directory (and subdirectories),
    then returns all classes that are subclasses of parent_class.

    :param parent_class: The class to find subclasses of.
    :param directory: The directory to search for subclasses.
    :return: A list of all subclasses of parent_class found in the directory.
    """
    # First import all modules to ensure classes are loaded into memory
    import_all_modules_from_dir(directory)

    # Gather all subclasses of the given parent class
    subclasses = all_subclasses(parent_class)
    return subclasses


# Example usage:
if __name__ == "__main__":

    class Animal:
        pass

    # Suppose "mymodules" contains files that define classes inheriting from Animal
    found_subclasses = find_all_subclasses_in_dir(Animal, "mymodules")
    for sc in found_subclasses:
        print("Found subclass:", sc.__name__)
@@ -76,7 +76,7 @@ def replace_user__ext_group_for_cc_pair(
    new_external_permissions = []
    for external_group in group_defs:
        for user_email in external_group.user_emails:
            user_id = email_id_map.get(user_email.lower())
            user_id = email_id_map.get(user_email)
            if user_id is None:
                logger.warning(
                    f"User in group {external_group.id}"

@@ -1,6 +1,4 @@
import asyncio
import json
from types import TracebackType
from typing import cast
from typing import Optional

@@ -8,11 +6,11 @@ import httpx
import openai
import vertexai  # type: ignore
import voyageai  # type: ignore
from cohere import AsyncClient as CohereAsyncClient
from cohere import Client as CohereClient
from fastapi import APIRouter
from fastapi import HTTPException
from google.oauth2 import service_account  # type: ignore
from litellm import aembedding
from litellm import embedding
from litellm.exceptions import RateLimitError
from retry import retry
from sentence_transformers import CrossEncoder  # type: ignore

@@ -65,31 +63,22 @@ class CloudEmbedding:
        provider: EmbeddingProvider,
        api_url: str | None = None,
        api_version: str | None = None,
        timeout: int = API_BASED_EMBEDDING_TIMEOUT,
    ) -> None:
        self.provider = provider
        self.api_key = api_key
        self.api_url = api_url
        self.api_version = api_version
        self.timeout = timeout
        self.http_client = httpx.AsyncClient(timeout=timeout)
        self._closed = False

    async def _embed_openai(
        self, texts: list[str], model: str | None
    ) -> list[Embedding]:
    def _embed_openai(self, texts: list[str], model: str | None) -> list[Embedding]:
        if not model:
            model = DEFAULT_OPENAI_MODEL

        # Use the OpenAI specific timeout for this one
        client = openai.AsyncOpenAI(
            api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT
        )
        client = openai.OpenAI(api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT)

        final_embeddings: list[Embedding] = []
        try:
            for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
                response = await client.embeddings.create(input=text_batch, model=model)
                response = client.embeddings.create(input=text_batch, model=model)
                final_embeddings.extend(
                    [embedding.embedding for embedding in response.data]
                )
@@ -104,19 +93,19 @@ class CloudEmbedding:
            logger.error(error_string)
            raise RuntimeError(error_string)

    async def _embed_cohere(
    def _embed_cohere(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[Embedding]:
        if not model:
            model = DEFAULT_COHERE_MODEL

        client = CohereAsyncClient(api_key=self.api_key)
        client = CohereClient(api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT)

        final_embeddings: list[Embedding] = []
        for text_batch in batch_list(texts, _COHERE_MAX_INPUT_LEN):
            # Does not use the same tokenizer as the Danswer API server but it's approximately the same
            # empirically it's only off by a very few tokens so it's not a big deal
            response = await client.embed(
            response = client.embed(
                texts=text_batch,
                model=model,
                input_type=embedding_type,
@@ -125,29 +114,26 @@ class CloudEmbedding:
            final_embeddings.extend(cast(list[Embedding], response.embeddings))
        return final_embeddings

    async def _embed_voyage(
    def _embed_voyage(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[Embedding]:
        if not model:
            model = DEFAULT_VOYAGE_MODEL

        client = voyageai.AsyncClient(
        client = voyageai.Client(
            api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT
        )

        response = await client.embed(
            texts=texts,
        response = client.embed(
            texts,
            model=model,
            input_type=embedding_type,
            truncation=True,
        )

        return response.embeddings

    async def _embed_azure(
        self, texts: list[str], model: str | None
    ) -> list[Embedding]:
        response = await aembedding(
    def _embed_azure(self, texts: list[str], model: str | None) -> list[Embedding]:
        response = embedding(
            model=model,
            input=texts,
            timeout=API_BASED_EMBEDDING_TIMEOUT,
@@ -156,9 +142,10 @@ class CloudEmbedding:
            api_version=self.api_version,
        )
        embeddings = [embedding["embedding"] for embedding in response.data]

        return embeddings

    async def _embed_vertex(
    def _embed_vertex(
        self, texts: list[str], model: str | None, embedding_type: str
    ) -> list[Embedding]:
        if not model:
@@ -171,7 +158,7 @@ class CloudEmbedding:
        vertexai.init(project=project_id, credentials=credentials)
        client = TextEmbeddingModel.from_pretrained(model)

        embeddings = await client.get_embeddings_async(
        embeddings = client.get_embeddings(
            [
                TextEmbeddingInput(
                    text,
@@ -179,11 +166,11 @@ class CloudEmbedding:
                )
                for text in texts
            ],
            auto_truncate=True,  # This is the default
            auto_truncate=True,  # Also this is default
        )
        return [embedding.values for embedding in embeddings]

    async def _embed_litellm_proxy(
    def _embed_litellm_proxy(
        self, texts: list[str], model_name: str | None
    ) -> list[Embedding]:
        if not model_name:
@@ -196,20 +183,22 @@ class CloudEmbedding:
            {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"}
        )

        response = await self.http_client.post(
            self.api_url,
            json={
                "model": model_name,
                "input": texts,
            },
            headers=headers,
        )
        response.raise_for_status()
        result = response.json()
        return [embedding["embedding"] for embedding in result["data"]]
        with httpx.Client() as client:
            response = client.post(
                self.api_url,
                json={
                    "model": model_name,
                    "input": texts,
                },
                headers=headers,
                timeout=API_BASED_EMBEDDING_TIMEOUT,
            )
            response.raise_for_status()
            result = response.json()
            return [embedding["embedding"] for embedding in result["data"]]

    @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY)
    async def embed(
    def embed(
        self,
        *,
        texts: list[str],
@@ -218,19 +207,19 @@ class CloudEmbedding:
        deployment_name: str | None = None,
    ) -> list[Embedding]:
        if self.provider == EmbeddingProvider.OPENAI:
            return await self._embed_openai(texts, model_name)
            return self._embed_openai(texts, model_name)
        elif self.provider == EmbeddingProvider.AZURE:
            return await self._embed_azure(texts, f"azure/{deployment_name}")
            return self._embed_azure(texts, f"azure/{deployment_name}")
        elif self.provider == EmbeddingProvider.LITELLM:
            return await self._embed_litellm_proxy(texts, model_name)
            return self._embed_litellm_proxy(texts, model_name)

        embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
        if self.provider == EmbeddingProvider.COHERE:
            return await self._embed_cohere(texts, model_name, embedding_type)
            return self._embed_cohere(texts, model_name, embedding_type)
        elif self.provider == EmbeddingProvider.VOYAGE:
            return await self._embed_voyage(texts, model_name, embedding_type)
            return self._embed_voyage(texts, model_name, embedding_type)
        elif self.provider == EmbeddingProvider.GOOGLE:
            return await self._embed_vertex(texts, model_name, embedding_type)
            return self._embed_vertex(texts, model_name, embedding_type)
        else:
            raise ValueError(f"Unsupported provider: {self.provider}")

@@ -244,30 +233,6 @@ class CloudEmbedding:
        logger.debug(f"Creating Embedding instance for provider: {provider}")
        return CloudEmbedding(api_key, provider, api_url, api_version)

    async def aclose(self) -> None:
        """Explicitly close the client."""
        if not self._closed:
            await self.http_client.aclose()
            self._closed = True

    async def __aenter__(self) -> "CloudEmbedding":
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        await self.aclose()

    def __del__(self) -> None:
        """Finalizer to warn about unclosed clients."""
        if not self._closed:
            logger.warning(
                "CloudEmbedding was not properly closed. Use 'async with' or call aclose()"
            )


def get_embedding_model(
    model_name: str,
@@ -277,6 +242,9 @@ def get_embedding_model(

    global _GLOBAL_MODELS_DICT  # A dictionary to store models

    if _GLOBAL_MODELS_DICT is None:
        _GLOBAL_MODELS_DICT = {}

    if model_name not in _GLOBAL_MODELS_DICT:
        logger.notice(f"Loading {model_name}")
        # Some model architectures that aren't built into the Transformers or Sentence

@@ -307,7 +275,7 @@ def get_local_reranking_model(


@simple_log_function_time()
async def embed_text(
def embed_text(
    texts: list[str],
    text_type: EmbedTextType,
    model_name: str | None,
@@ -343,18 +311,18 @@ async def embed_text(
            "Cloud models take an explicit text type instead."
        )

        async with CloudEmbedding(
        cloud_model = CloudEmbedding(
            api_key=api_key,
            provider=provider_type,
            api_url=api_url,
            api_version=api_version,
        ) as cloud_model:
            embeddings = await cloud_model.embed(
                texts=texts,
                model_name=model_name,
                deployment_name=deployment_name,
                text_type=text_type,
            )
        )
        embeddings = cloud_model.embed(
            texts=texts,
            model_name=model_name,
            deployment_name=deployment_name,
            text_type=text_type,
        )

        if any(embedding is None for embedding in embeddings):
            error_message = "Embeddings contain None values\n"
@@ -370,12 +338,8 @@ async def embed_text(
        local_model = get_embedding_model(
            model_name=model_name, max_context_length=max_context_length
        )
        # Run CPU-bound embedding in a thread pool
        embeddings_vectors = await asyncio.get_event_loop().run_in_executor(
            None,
            lambda: local_model.encode(
                prefixed_texts, normalize_embeddings=normalize_embeddings
            ),
        embeddings_vectors = local_model.encode(
            prefixed_texts, normalize_embeddings=normalize_embeddings
        )
        embeddings = [
            embedding if isinstance(embedding, list) else embedding.tolist()
@@ -393,31 +357,27 @@ async def embed_text(


@simple_log_function_time()
async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
    cross_encoder = get_local_reranking_model(model_name)
    # Run CPU-bound reranking in a thread pool
    return await asyncio.get_event_loop().run_in_executor(
        None,
        lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),  # type: ignore
    )
    return cross_encoder.predict([(query, doc) for doc in docs]).tolist()  # type: ignore


async def cohere_rerank(
def cohere_rerank(
    query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
    cohere_client = CohereAsyncClient(api_key=api_key)
    response = await cohere_client.rerank(query=query, documents=docs, model=model_name)
    cohere_client = CohereClient(api_key=api_key)
    response = cohere_client.rerank(query=query, documents=docs, model=model_name)
    results = response.results
    sorted_results = sorted(results, key=lambda item: item.index)
    return [result.relevance_score for result in sorted_results]


async def litellm_rerank(
def litellm_rerank(
    query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
    headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
    async with httpx.AsyncClient() as client:
        response = await client.post(
    with httpx.Client() as client:
        response = client.post(
            api_url,
            json={
                "model": model_name,
@@ -451,7 +411,7 @@ async def process_embed_request(
    else:
        prefix = None

    embeddings = await embed_text(
    embeddings = embed_text(
        texts=embed_request.texts,
        model_name=embed_request.model_name,
        deployment_name=embed_request.deployment_name,
@@ -491,7 +451,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons

    try:
        if rerank_request.provider_type is None:
            sim_scores = await local_rerank(
            sim_scores = local_rerank(
                query=rerank_request.query,
                docs=rerank_request.documents,
                model_name=rerank_request.model_name,
@@ -501,7 +461,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
            if rerank_request.api_url is None:
                raise ValueError("API URL is required for LiteLLM reranking.")

            sim_scores = await litellm_rerank(
            sim_scores = litellm_rerank(
                query=rerank_request.query,
                docs=rerank_request.documents,
                api_url=rerank_request.api_url,
@@ -514,7 +474,7 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
        elif rerank_request.provider_type == RerankerProvider.COHERE:
            if rerank_request.api_key is None:
                raise RuntimeError("Cohere Rerank Requires an API Key")
            sim_scores = await cohere_rerank(
            sim_scores = cohere_rerank(
                query=rerank_request.query,
                docs=rerank_request.documents,
                model_name=rerank_request.model_name,

@@ -6,12 +6,12 @@ router = APIRouter(prefix="/api")
|
||||
|
||||
|
||||
@router.get("/health")
|
||||
async def healthcheck() -> Response:
|
||||
def healthcheck() -> Response:
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@router.get("/gpu-status")
|
||||
async def gpu_status() -> dict[str, bool | str]:
|
||||
def gpu_status() -> dict[str, bool | str]:
|
||||
if torch.cuda.is_available():
|
||||
return {"gpu_available": True, "type": "cuda"}
|
||||
elif torch.backends.mps.is_available():
|
||||
|
||||
@@ -1,4 +1,3 @@
import asyncio
import time
from collections.abc import Callable
from collections.abc import Generator
@@ -22,39 +21,21 @@ def simple_log_function_time(
    include_args: bool = False,
) -> Callable[[F], F]:
    def decorator(func: F) -> F:
        if asyncio.iscoroutinefunction(func):
        @wraps(func)
        def wrapped_func(*args: Any, **kwargs: Any) -> Any:
            start_time = time.time()
            result = func(*args, **kwargs)
            elapsed_time_str = str(time.time() - start_time)
            log_name = func_name or func.__name__
            args_str = f" args={args} kwargs={kwargs}" if include_args else ""
            final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
            if debug_only:
                logger.debug(final_log)
            else:
                logger.notice(final_log)

            @wraps(func)
            async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any:
                start_time = time.time()
                result = await func(*args, **kwargs)
                elapsed_time_str = str(time.time() - start_time)
                log_name = func_name or func.__name__
                args_str = f" args={args} kwargs={kwargs}" if include_args else ""
                final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
                if debug_only:
                    logger.debug(final_log)
                else:
                    logger.notice(final_log)
                return result
            return result

            return cast(F, wrapped_async_func)
        else:

            @wraps(func)
            def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any:
                start_time = time.time()
                result = func(*args, **kwargs)
                elapsed_time_str = str(time.time() - start_time)
                log_name = func_name or func.__name__
                args_str = f" args={args} kwargs={kwargs}" if include_args else ""
                final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
                if debug_only:
                    logger.debug(final_log)
                else:
                    logger.notice(final_log)
                return result

            return cast(F, wrapped_sync_func)
        return cast(F, wrapped_func)

    return decorator
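With the async branch removed above, only the sync wrapping path remains; the decorator's parameters (func_name, debug_only, include_args) are unchanged. A small usage sketch:

@simple_log_function_time(func_name="double", debug_only=True, include_args=True)
def double(x: int) -> int:
    return x * 2

double(21)  # logs something like "double args=(21,) kwargs={} took 1.9e-06 seconds" at debug level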

@@ -29,7 +29,7 @@ trafilatura==1.12.2
langchain==0.1.17
langchain-core==0.1.50
langchain-text-splitters==0.0.1
litellm==1.54.1
litellm==1.53.1
lxml==5.3.0
lxml_html_clean==0.2.2
llama-index==0.9.45

@@ -1,34 +1,30 @@
black==23.3.0
boto3-stubs[s3]==1.34.133
celery-types==0.19.0
cohere==5.6.1
google-cloud-aiplatform==1.58.0
lxml==5.3.0
lxml_html_clean==0.2.2
mypy-extensions==1.0.0
mypy==1.8.0
pandas-stubs==2.2.3.241009
pandas==2.2.3
pre-commit==3.2.2
pytest-asyncio==0.22.0
pytest==7.4.4
reorder-python-imports==3.9.0
ruff==0.0.286
sentence-transformers==2.6.1
trafilatura==1.12.2
types-PyYAML==6.0.12.11
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13
types-oauthlib==3.2.0.9
types-passlib==1.7.7.20240106
types-setuptools==68.0.0.3
types-Pillow==10.2.0.20240822
types-passlib==1.7.7.20240106
types-psutil==5.9.5.17
types-psycopg2==2.9.21.10
types-python-dateutil==2.8.19.13
types-pytz==2023.3.1.1
types-PyYAML==6.0.12.11
types-regex==2023.3.23.1
types-requests==2.28.11.17
types-retry==0.9.9.3
types-setuptools==68.0.0.3
types-urllib3==1.26.25.11
voyageai==0.2.3
trafilatura==1.12.2
lxml==5.3.0
lxml_html_clean==0.2.2
boto3-stubs[s3]==1.34.133
pandas==2.2.3
pandas-stubs==2.2.3.241009
cohere==5.6.1

@@ -12,5 +12,5 @@ torch==2.2.0
transformers==4.39.2
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.54.1
litellm==1.50.2
sentry-sdk[fastapi,celery,starlette]==2.14.0

@@ -69,10 +69,8 @@ class TenantManager:
        return AllUsersResponse(
            accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
            invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
            slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]],
            accepted_pages=data["accepted_pages"],
            invited_pages=data["invited_pages"],
            slack_users_pages=data["slack_users_pages"],
        )

    @staticmethod

@@ -130,10 +130,8 @@ class UserManager:
        all_users = AllUsersResponse(
            accepted=[FullUserSnapshot(**user) for user in data["accepted"]],
            invited=[InvitedUserSnapshot(**user) for user in data["invited"]],
            slack_users=[FullUserSnapshot(**user) for user in data["slack_users"]],
            accepted_pages=data["accepted_pages"],
            invited_pages=data["invited_pages"],
            slack_users_pages=data["slack_users_pages"],
        )
        for accepted_user in all_users.accepted:
            if accepted_user.email == user.email and accepted_user.id == user.id:

@@ -3,8 +3,6 @@ from datetime import datetime
from datetime import timezone
from typing import Any

import pytest

from danswer.connectors.models import InputType
from danswer.db.enums import AccessType
from danswer.server.documents.models import DocumentSource
@@ -25,7 +23,7 @@ from tests.integration.common_utils.vespa import vespa_fixture
from tests.integration.connector_job_tests.slack.slack_api_utils import SlackManager


@pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False)
# @pytest.mark.xfail(reason="flaky - see DAN-789 for example", strict=False)
def test_slack_permission_sync(
    reset: None,
    vespa_client: vespa_fixture,

@@ -27,6 +27,13 @@ def test_limited(reset: None) -> None:
    )
    assert response.status_code == 200

    # test basic endpoints
    response = requests.get(
        f"{API_SERVER_URL}/input_prompt",
        headers=api_key.headers,
    )
    assert response.status_code == 403

    # test admin endpoints
    response = requests.get(
        f"{API_SERVER_URL}/admin/api-key",

@@ -72,10 +72,8 @@ def process_text(
    processor = CitationProcessor(
        context_docs=mock_docs,
        doc_id_to_rank_map=mapping,
        display_doc_order_dict=mock_doc_id_to_rank_map,
        stop_stream=None,
    )

    result: list[DanswerAnswerPiece | CitationInfo] = []
    for token in tokens:
        result.extend(processor.process_token(token))
@@ -88,7 +86,6 @@ def process_text(
            final_answer_text += piece.answer_piece or ""
        elif isinstance(piece, CitationInfo):
            citations.append(piece)

    return final_answer_text, citations


@@ -1,132 +0,0 @@
from datetime import datetime

import pytest

from danswer.chat.models import CitationInfo
from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import LlmDoc
from danswer.chat.stream_processing.citation_processing import CitationProcessor
from danswer.chat.stream_processing.utils import DocumentIdOrderMapping
from danswer.configs.constants import DocumentSource


"""
This module contains tests for the citation extraction functionality in Danswer,
specifically the substitution of the number of the document cited in the UI. (The LLM
sees the sources after re-ranking and the relevance check; the UI sees them before
these steps.) This module is a derivative of test_citation_processing.py.

The tests focus specifically on the substitution of the number of the document cited
in the UI.

Key components:
- mock_docs: A list of mock LlmDoc objects used for testing.
- mock_doc_mapping: A dictionary mapping document IDs to their initial ranks.
- mock_doc_mapping_rerank: A dictionary mapping document IDs to their ranks after re-ranking/relevance check.
- process_text: A helper function that simulates the citation extraction process.
- test_citation_substitution: A parametrized test function covering various citation scenarios.

To add new test cases:
1. Add a new tuple to the @pytest.mark.parametrize decorator of test_citation_substitution.
2. Each tuple should contain:
   - A descriptive test name (string)
   - Input tokens (list of strings)
   - Expected output text (string)
   - Expected citations (list of document IDs)
"""


mock_docs = [
    LlmDoc(
        document_id=f"doc_{int(id/2)}",
        content="Document is a doc",
        blurb=f"Document #{id}",
        semantic_identifier=f"Doc {id}",
        source_type=DocumentSource.WEB,
        metadata={},
        updated_at=datetime.now(),
        link=f"https://{int(id/2)}.com" if int(id / 2) % 2 == 0 else None,
        source_links={0: "https://mintlify.com/docs/settings/broken-links"},
        match_highlights=[],
    )
    for id in range(10)
]

mock_doc_mapping = {
    "doc_0": 1,
    "doc_1": 2,
    "doc_2": 3,
    "doc_3": 4,
    "doc_4": 5,
    "doc_5": 6,
}

mock_doc_mapping_rerank = {
    "doc_0": 2,
    "doc_1": 1,
    "doc_2": 4,
    "doc_3": 3,
    "doc_4": 6,
    "doc_5": 5,
}


@pytest.fixture
def mock_data() -> tuple[list[LlmDoc], dict[str, int]]:
    return mock_docs, mock_doc_mapping


def process_text(
    tokens: list[str], mock_data: tuple[list[LlmDoc], dict[str, int]]
) -> tuple[str, list[CitationInfo]]:
    mock_docs, mock_doc_id_to_rank_map = mock_data
    mapping = DocumentIdOrderMapping(order_mapping=mock_doc_id_to_rank_map)
    processor = CitationProcessor(
        context_docs=mock_docs,
        doc_id_to_rank_map=mapping,
        display_doc_order_dict=mock_doc_mapping_rerank,
        stop_stream=None,
    )

    result: list[DanswerAnswerPiece | CitationInfo] = []
    for token in tokens:
        result.extend(processor.process_token(token))
    result.extend(processor.process_token(None))

    final_answer_text = ""
    citations = []
    for piece in result:
        if isinstance(piece, DanswerAnswerPiece):
            final_answer_text += piece.answer_piece or ""
        elif isinstance(piece, CitationInfo):
            citations.append(piece)

    return final_answer_text, citations


@pytest.mark.parametrize(
    "test_name, input_tokens, expected_text, expected_citations",
    [
        (
            "Single citation",
            ["Gro", "wth! [", "1", "]", "."],
            "Growth! [[2]](https://0.com).",
            ["doc_0"],
        ),
    ],
)
def test_citation_substitution(
    mock_data: tuple[list[LlmDoc], dict[str, int]],
    test_name: str,
    input_tokens: list[str],
    expected_text: str,
    expected_citations: list[str],
) -> None:
    final_answer_text, citations = process_text(input_tokens, mock_data)
    assert (
        final_answer_text.strip() == expected_text.strip()
    ), f"Test '{test_name}' failed: Final answer text does not match expected output."
    assert [
        citation.document_id for citation in citations
    ] == expected_citations, (
        f"Test '{test_name}' failed: Citations do not match expected output."
    )
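The single parametrized case above exercises the rank substitution directly: the LLM emits "[1]", which the order mapping resolves to doc_0, and the display mapping assigns doc_0 rank 2, so the UI renders [[2]](https://0.com). Below is a minimal sketch of that lookup; substitute_citation is a hypothetical helper for illustration and is not part of the deleted module:

def substitute_citation(
    llm_rank: int,
    order_mapping: dict[str, int],    # doc_id -> rank as the LLM saw the docs
    display_mapping: dict[str, int],  # doc_id -> rank as the UI displays them
    links: dict[str, str | None],     # doc_id -> link, if any
) -> str:
    # Resolve which document the LLM meant by its rank, then swap in the
    # rank that the UI ordering assigns to that same document.
    doc_id = next(d for d, rank in order_mapping.items() if rank == llm_rank)
    display_rank = display_mapping[doc_id]
    link = links.get(doc_id)
    return f"[[{display_rank}]]({link})" if link else f"[[{display_rank}]]"


order = {"doc_0": 1, "doc_1": 2}
display = {"doc_0": 2, "doc_1": 1}
links = {"doc_0": "https://0.com", "doc_1": None}

assert substitute_citation(1, order, display, links) == "[[2]](https://0.com)"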
@@ -1,120 +0,0 @@
from typing import List

from danswer.configs.app_configs import MAX_DOCUMENT_CHARS
from danswer.connectors.models import Document
from danswer.connectors.models import DocumentSource
from danswer.connectors.models import Section
from danswer.indexing.indexing_pipeline import filter_documents


def create_test_document(
    doc_id: str = "test_id",
    title: str | None = "Test Title",
    semantic_id: str = "test_semantic_id",
    sections: List[Section] | None = None,
) -> Document:
    if sections is None:
        sections = [Section(text="Test content", link="test_link")]
    return Document(
        id=doc_id,
        title=title,
        semantic_identifier=semantic_id,
        sections=sections,
        source=DocumentSource.FILE,
        metadata={},
    )


def test_filter_documents_empty_title_and_content() -> None:
    doc = create_test_document(
        title="", semantic_id="", sections=[Section(text="", link="test_link")]
    )
    result = filter_documents([doc])
    assert len(result) == 0


def test_filter_documents_empty_title_with_content() -> None:
    doc = create_test_document(
        title="", sections=[Section(text="Valid content", link="test_link")]
    )
    result = filter_documents([doc])
    assert len(result) == 1
    assert result[0].id == "test_id"


def test_filter_documents_empty_content_with_title() -> None:
    doc = create_test_document(
        title="Valid Title", sections=[Section(text="", link="test_link")]
    )
    result = filter_documents([doc])
    assert len(result) == 1
    assert result[0].id == "test_id"


def test_filter_documents_exceeding_max_chars() -> None:
    if not MAX_DOCUMENT_CHARS:  # Skip if no max chars configured
        return
    long_text = "a" * (MAX_DOCUMENT_CHARS + 1)
    doc = create_test_document(sections=[Section(text=long_text, link="test_link")])
    result = filter_documents([doc])
    assert len(result) == 0


def test_filter_documents_valid_document() -> None:
    doc = create_test_document(
        title="Valid Title", sections=[Section(text="Valid content", link="test_link")]
    )
    result = filter_documents([doc])
    assert len(result) == 1
    assert result[0].id == "test_id"
    assert result[0].title == "Valid Title"


def test_filter_documents_whitespace_only() -> None:
    doc = create_test_document(
        title="   ", semantic_id="   ", sections=[Section(text="   ", link="test_link")]
    )
    result = filter_documents([doc])
    assert len(result) == 0


def test_filter_documents_semantic_id_no_title() -> None:
    doc = create_test_document(
        title=None,
        semantic_id="Valid Semantic ID",
        sections=[Section(text="Valid content", link="test_link")],
    )
    result = filter_documents([doc])
    assert len(result) == 1
    assert result[0].semantic_identifier == "Valid Semantic ID"


def test_filter_documents_multiple_sections() -> None:
    doc = create_test_document(
        sections=[
            Section(text="Content 1", link="test_link"),
            Section(text="Content 2", link="test_link"),
            Section(text="Content 3", link="test_link"),
        ]
    )
    result = filter_documents([doc])
    assert len(result) == 1
    assert len(result[0].sections) == 3


def test_filter_documents_multiple_documents() -> None:
    docs = [
        create_test_document(doc_id="1", title="Title 1"),
        create_test_document(
            doc_id="2", title="", sections=[Section(text="", link="test_link")]
        ),  # Should be filtered
        create_test_document(doc_id="3", title="Title 3"),
    ]
    result = filter_documents(docs)
    assert len(result) == 2
    assert {doc.id for doc in result} == {"1", "3"}


def test_filter_documents_empty_batch() -> None:
    result = filter_documents([])
    assert len(result) == 0
@@ -1,198 +0,0 @@
import asyncio
import time
from collections.abc import AsyncGenerator
from typing import Any
from typing import List
from unittest.mock import AsyncMock
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
from httpx import AsyncClient
from litellm.exceptions import RateLimitError

from model_server.encoders import CloudEmbedding
from model_server.encoders import embed_text
from model_server.encoders import local_rerank
from model_server.encoders import process_embed_request
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import EmbedRequest


@pytest.fixture
async def mock_http_client() -> AsyncGenerator[AsyncMock, None]:
    with patch("httpx.AsyncClient") as mock:
        client = AsyncMock(spec=AsyncClient)
        mock.return_value = client
        client.post = AsyncMock()
        async with client as c:
            yield c


@pytest.fixture
def sample_embeddings() -> List[List[float]]:
    return [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]


@pytest.mark.asyncio
async def test_cloud_embedding_context_manager() -> None:
    async with CloudEmbedding("fake-key", EmbeddingProvider.OPENAI) as embedding:
        assert not embedding._closed
    assert embedding._closed


@pytest.mark.asyncio
async def test_cloud_embedding_explicit_close() -> None:
    embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
    assert not embedding._closed
    await embedding.aclose()
    assert embedding._closed


@pytest.mark.asyncio
async def test_openai_embedding(
    mock_http_client: AsyncMock, sample_embeddings: List[List[float]]
) -> None:
    with patch("openai.AsyncOpenAI") as mock_openai:
        mock_client = AsyncMock()
        mock_openai.return_value = mock_client

        mock_response = MagicMock()
        mock_response.data = [MagicMock(embedding=emb) for emb in sample_embeddings]
        mock_client.embeddings.create = AsyncMock(return_value=mock_response)

        embedding = CloudEmbedding("fake-key", EmbeddingProvider.OPENAI)
        result = await embedding._embed_openai(
            ["test1", "test2"], "text-embedding-ada-002"
        )

        assert result == sample_embeddings
        mock_client.embeddings.create.assert_called_once()


@pytest.mark.asyncio
async def test_embed_text_cloud_provider() -> None:
    with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
        mock_embed.return_value = [[0.1, 0.2], [0.3, 0.4]]
        mock_embed.side_effect = AsyncMock(return_value=[[0.1, 0.2], [0.3, 0.4]])

        result = await embed_text(
            texts=["test1", "test2"],
            text_type=EmbedTextType.QUERY,
            model_name="fake-model",
            deployment_name=None,
            max_context_length=512,
            normalize_embeddings=True,
            api_key="fake-key",
            provider_type=EmbeddingProvider.OPENAI,
            prefix=None,
            api_url=None,
            api_version=None,
        )

        assert result == [[0.1, 0.2], [0.3, 0.4]]
        mock_embed.assert_called_once()


@pytest.mark.asyncio
async def test_embed_text_local_model() -> None:
    with patch("model_server.encoders.get_embedding_model") as mock_get_model:
        mock_model = MagicMock()
        mock_model.encode.return_value = [[0.1, 0.2], [0.3, 0.4]]
        mock_get_model.return_value = mock_model

        result = await embed_text(
            texts=["test1", "test2"],
            text_type=EmbedTextType.QUERY,
            model_name="fake-local-model",
            deployment_name=None,
            max_context_length=512,
            normalize_embeddings=True,
            api_key=None,
            provider_type=None,
            prefix=None,
            api_url=None,
            api_version=None,
        )

        assert result == [[0.1, 0.2], [0.3, 0.4]]
        mock_model.encode.assert_called_once()


@pytest.mark.asyncio
async def test_local_rerank() -> None:
    with patch("model_server.encoders.get_local_reranking_model") as mock_get_model:
        mock_model = MagicMock()
        mock_array = MagicMock()
        mock_array.tolist.return_value = [0.8, 0.6]
        mock_model.predict.return_value = mock_array
        mock_get_model.return_value = mock_model

        result = await local_rerank(
            query="test query", docs=["doc1", "doc2"], model_name="fake-rerank-model"
        )

        assert result == [0.8, 0.6]
        mock_model.predict.assert_called_once()


@pytest.mark.asyncio
async def test_rate_limit_handling() -> None:
    with patch("model_server.encoders.CloudEmbedding.embed") as mock_embed:
        mock_embed.side_effect = RateLimitError(
            "Rate limit exceeded", llm_provider="openai", model="fake-model"
        )

        with pytest.raises(RateLimitError):
            await embed_text(
                texts=["test"],
                text_type=EmbedTextType.QUERY,
                model_name="fake-model",
                deployment_name=None,
                max_context_length=512,
                normalize_embeddings=True,
                api_key="fake-key",
                provider_type=EmbeddingProvider.OPENAI,
                prefix=None,
                api_url=None,
                api_version=None,
            )


@pytest.mark.asyncio
async def test_concurrent_embeddings() -> None:
    def mock_encode(*args: Any, **kwargs: Any) -> List[List[float]]:
        time.sleep(5)
        return [[0.1, 0.2, 0.3]]

    test_req = EmbedRequest(
        texts=["test"],
        model_name="'nomic-ai/nomic-embed-text-v1'",
        deployment_name=None,
        max_context_length=512,
        normalize_embeddings=True,
        api_key=None,
        provider_type=None,
        text_type=EmbedTextType.QUERY,
        manual_query_prefix=None,
        manual_passage_prefix=None,
        api_url=None,
        api_version=None,
    )

    with patch("model_server.encoders.get_embedding_model") as mock_get_model:
        mock_model = MagicMock()
        mock_model.encode = mock_encode
        mock_get_model.return_value = mock_model
        start_time = time.time()

        tasks = [process_embed_request(test_req) for _ in range(5)]
        await asyncio.gather(*tasks)

        end_time = time.time()

        # 5 * 5 seconds = 25 seconds if the calls were fully serialized; this test
        # ensures that the embedding calls are at least yielding the thread. However,
        # the developer may still introduce unnecessary blocking above the mock and
        # this test will still pass, as long as the added blocking stays under
        # (7 - 5) / 5 seconds per request.
        assert end_time - start_time < 7
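The timing assertion above only holds if process_embed_request moves the blocking encode call off the event loop (for example onto a worker thread); five inline 5-second sleeps would serialize to roughly 25 seconds. Below is a minimal sketch of the pattern under that assumption, using asyncio.to_thread; the actual offloading mechanism inside model_server may differ:

import asyncio
import time


def blocking_encode(texts: list[str]) -> list[list[float]]:
    time.sleep(5)  # stands in for CPU-bound model inference
    return [[0.1, 0.2, 0.3] for _ in texts]


async def process_embed_request_sketch(texts: list[str]) -> list[list[float]]:
    # Offloading keeps the event loop free to start the other requests,
    # so five of these overlap in about 5 seconds rather than 25.
    return await asyncio.to_thread(blocking_encode, texts)


async def main() -> None:
    start = time.monotonic()
    reqs = (process_embed_request_sketch(["test"]) for _ in range(5))
    await asyncio.gather(*reqs)
    assert time.monotonic() - start < 7


asyncio.run(main())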
ct.yaml (2 changes)
@@ -6,7 +6,7 @@ chart-dirs:

# must be kept in sync with Chart.yaml
chart-repos:
  - vespa=https://onyx-dot-app.github.io/vespa-helm-charts
  - vespa=https://danswer-ai.github.io/vespa-helm-charts
  - postgresql=https://charts.bitnami.com/bitnami

helm-extra-args: --debug --timeout 600s

@@ -183,13 +183,6 @@ services:
      - GONG_CONNECTOR_START_TIME=${GONG_CONNECTOR_START_TIME:-}
      - NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP=${NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP:-}
      - GITHUB_CONNECTOR_BASE_URL=${GITHUB_CONNECTOR_BASE_URL:-}
      - MAX_DOCUMENT_CHARS=${MAX_DOCUMENT_CHARS:-}
      - MAX_FILE_SIZE_BYTES=${MAX_FILE_SIZE_BYTES:-}
      # Egnyte OAuth Configs
      - EGNYTE_CLIENT_ID=${EGNYTE_CLIENT_ID:-}
      - EGNYTE_CLIENT_SECRET=${EGNYTE_CLIENT_SECRET:-}
      - EGNYTE_BASE_DOMAIN=${EGNYTE_BASE_DOMAIN:-}
      - EGNYTE_LOCALHOST_OVERRIDE=${EGNYTE_LOCALHOST_OVERRIDE:-}
      # Celery Configs (defaults are set in the supervisord.conf file;
      # prefer doing that to have one source of defaults)
      - CELERY_WORKER_INDEXING_CONCURRENCY=${CELERY_WORKER_INDEXING_CONCURRENCY:-}

@@ -3,13 +3,13 @@ dependencies:
  repository: https://charts.bitnami.com/bitnami
  version: 14.3.1
- name: vespa
  repository: https://onyx-dot-app.github.io/vespa-helm-charts
  version: 0.2.18
  repository: https://danswer-ai.github.io/vespa-helm-charts
  version: 0.2.16
- name: nginx
  repository: oci://registry-1.docker.io/bitnamicharts
  version: 15.14.0
- name: redis
  repository: https://charts.bitnami.com/bitnami
  version: 20.1.0
digest: sha256:5c9eb3d55d5f8e3beb64f26d26f686c8d62755daa10e2e6d87530bdf2fbbf957
generated: "2024-12-10T10:47:35.812483-08:00"
digest: sha256:711bbb76ba6ab604a36c9bf1839ab6faa5610afb21e535afd933c78f2d102232
generated: "2024-11-07T09:39:30.17171-08:00"

@@ -23,8 +23,8 @@ dependencies:
    repository: https://charts.bitnami.com/bitnami
    condition: postgresql.enabled
  - name: vespa
    version: 0.2.18
    repository: https://onyx-dot-app.github.io/vespa-helm-charts
    version: 0.2.16
    repository: https://danswer-ai.github.io/vespa-helm-charts
    condition: vespa.enabled
  - name: nginx
    version: 15.14.0

@@ -61,8 +61,6 @@ data:
  WEB_CONNECTOR_VALIDATE_URLS: ""
  GONG_CONNECTOR_START_TIME: ""
  NOTION_CONNECTOR_ENABLE_RECURSIVE_PAGE_LOOKUP: ""
  MAX_DOCUMENT_CHARS: ""
  MAX_FILE_SIZE_BYTES: ""
  # DanswerBot SlackBot Configs
  DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER: ""
  DANSWER_BOT_DISPLAY_ERROR_MSGS: ""

@@ -66,9 +66,6 @@ ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}

ARG NEXT_PUBLIC_CLOUD_ENABLED
ENV NEXT_PUBLIC_CLOUD_ENABLED=${NEXT_PUBLIC_CLOUD_ENABLED}

ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}

@@ -141,9 +138,6 @@ ARG NEXT_PUBLIC_POSTHOG_HOST
ENV NEXT_PUBLIC_POSTHOG_KEY=${NEXT_PUBLIC_POSTHOG_KEY}
ENV NEXT_PUBLIC_POSTHOG_HOST=${NEXT_PUBLIC_POSTHOG_HOST}

ARG NEXT_PUBLIC_CLOUD_ENABLED
ENV NEXT_PUBLIC_CLOUD_ENABLED=${NEXT_PUBLIC_CLOUD_ENABLED}

ARG NEXT_PUBLIC_SENTRY_DSN
ENV NEXT_PUBLIC_SENTRY_DSN=${NEXT_PUBLIC_SENTRY_DSN}

[Two binary image files removed (12 KiB and 769 KiB); diffs not shown.]
web/public/Wikipedia.svg (new file, 535 lines, 164 KiB; diff suppressed because one or more lines are too long)
@@ -2,11 +2,6 @@ import CardSection from "@/components/admin/CardSection";
import { getNameFromPath } from "@/lib/fileUtils";
import { ValidSources } from "@/lib/types";
import Title from "@/components/ui/title";
import { EditIcon } from "@/components/icons/icons";

import { useState } from "react";
import { ChevronUpIcon } from "lucide-react";
import { ChevronDownIcon } from "@/components/icons/icons";

function convertObjectToString(obj: any): string | any {
  // Check if obj is an object and not an array or null
@@ -44,83 +39,14 @@ function buildConfigEntries(
  return obj;
}

function ConfigItem({ label, value }: { label: string; value: any }) {
  const [isExpanded, setIsExpanded] = useState(false);
  const isExpandable = Array.isArray(value) && value.length > 5;

  const renderValue = () => {
    if (Array.isArray(value)) {
      const displayedItems = isExpanded ? value : value.slice(0, 5);
      return (
        <ul className="list-disc max-w-full pl-4 mt-2 overflow-x-auto">
          {displayedItems.map((item, index) => (
            <li
              key={index}
              className="mb-1 max-w-full overflow-hidden text-right text-ellipsis whitespace-nowrap"
            >
              {convertObjectToString(item)}
            </li>
          ))}
        </ul>
      );
    } else if (typeof value === "object" && value !== null) {
      return (
        <div className="mt-2 overflow-x-auto">
          {Object.entries(value).map(([key, val]) => (
            <div key={key} className="mb-1">
              <span className="font-semibold">{key}:</span>{" "}
              {convertObjectToString(val)}
            </div>
          ))}
        </div>
      );
    }
    return convertObjectToString(value) || "-";
  };

  return (
    <li className="w-full py-2">
      <div className="flex items-center justify-between w-full">
        <span className="mb-2">{label}</span>
        <div className="mt-2 overflow-x-auto w-fit">
          {renderValue()}

          {isExpandable && (
            <button
              onClick={() => setIsExpanded(!isExpanded)}
              className="mt-2 ml-auto text-text-600 hover:text-text-800 flex items-center"
            >
              {isExpanded ? (
                <>
                  <ChevronUpIcon className="h-4 w-4 mr-1" />
                  Show less
                </>
              ) : (
                <>
                  <ChevronDownIcon className="h-4 w-4 mr-1" />
                  Show all ({value.length} items)
                </>
              )}
            </button>
          )}
        </div>
      </div>
    </li>
  );
}

export function AdvancedConfigDisplay({
  pruneFreq,
  refreshFreq,
  indexingStart,
  onRefreshEdit,
  onPruningEdit,
}: {
  pruneFreq: number | null;
  refreshFreq: number | null;
  indexingStart: Date | null;
  onRefreshEdit: () => void;
  onPruningEdit: () => void;
}) {
  const formatRefreshFrequency = (seconds: number | null): string => {
    if (seconds === null) return "-";
@@ -149,21 +75,14 @@ export function AdvancedConfigDisplay({
    <>
      <Title className="mt-8 mb-2">Advanced Configuration</Title>
      <CardSection>
        <ul className="w-full text-sm divide-y divide-background-200 dark:divide-background-700">
        <ul className="w-full text-sm divide-y divide-neutral-200 dark:divide-neutral-700">
          {pruneFreq && (
            <li
              key={0}
              className="w-full flex justify-between items-center py-2"
            >
              <span>Pruning Frequency</span>
              <span className="ml-auto w-24">
                {formatPruneFrequency(pruneFreq)}
              </span>
              <span className="w-8 text-right">
                <button onClick={() => onPruningEdit()}>
                  <EditIcon size={12} />
                </button>
              </span>
              <span>{formatPruneFrequency(pruneFreq)}</span>
            </li>
          )}
          {refreshFreq && (
@@ -172,14 +91,7 @@ export function AdvancedConfigDisplay({
              className="w-full flex justify-between items-center py-2"
            >
              <span>Refresh Frequency</span>
              <span className="ml-auto w-24">
                {formatRefreshFrequency(refreshFreq)}
              </span>
              <span className="w-8 text-right">
                <button onClick={() => onRefreshEdit()}>
                  <EditIcon size={12} />
                </button>
              </span>
              <span>{formatRefreshFrequency(refreshFreq)}</span>
            </li>
          )}
          {indexingStart && (
@@ -215,9 +127,15 @@ export function ConfigDisplay({
    <>
      <Title className="mb-2">Configuration</Title>
      <CardSection>
        <ul className="w-full text-sm divide-y divide-background-200 dark:divide-background-700">
        <ul className="w-full text-sm divide-y divide-neutral-200 dark:divide-neutral-700">
          {configEntries.map(([key, value]) => (
            <ConfigItem key={key} label={key} value={value} />
            <li
              key={key}
              className="w-full flex justify-between items-center py-2"
            >
              <span>{key}</span>
              <span>{convertObjectToString(value) || "-"}</span>
            </li>
          ))}
        </ul>
      </CardSection>

@@ -7,10 +7,7 @@ import { SourceIcon } from "@/components/SourceIcon";
import { CCPairStatus } from "@/components/Status";
import { usePopup } from "@/components/admin/connectors/Popup";
import CredentialSection from "@/components/credentials/CredentialSection";
import {
  updateConnectorCredentialPairName,
  updateConnectorCredentialPairProperty,
} from "@/lib/connector";
import { updateConnectorCredentialPairName } from "@/lib/connector";
import { credentialTemplates } from "@/lib/connectors/credentials";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { ValidSources } from "@/lib/types";
@@ -29,33 +26,12 @@ import { buildCCPairInfoUrl } from "./lib";
import { CCPairFullInfo, ConnectorCredentialPairStatus } from "./types";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { Button } from "@/components/ui/button";
import EditPropertyModal from "@/components/modals/EditPropertyModal";

import * as Yup from "yup";

// since the uploaded files are cleaned up after some period of time
// re-indexing will not work for the file connector. Also, it would not
// make sense to re-index, since the files will not have changed.
const CONNECTOR_TYPES_THAT_CANT_REINDEX: ValidSources[] = [ValidSources.File];

// synchronize these validations with the SQLAlchemy connector class until we have a
// centralized schema for both frontend and backend
const RefreshFrequencySchema = Yup.object().shape({
  propertyValue: Yup.number()
    .typeError("Property value must be a valid number")
    .integer("Property value must be an integer")
    .min(60, "Property value must be greater than or equal to 60")
    .required("Property value is required"),
});

const PruneFrequencySchema = Yup.object().shape({
  propertyValue: Yup.number()
    .typeError("Property value must be a valid number")
    .integer("Property value must be an integer")
    .min(86400, "Property value must be greater than or equal to 86400")
    .required("Property value is required"),
});

function Main({ ccPairId }: { ccPairId: number }) {
  const router = useRouter(); // Initialize the router
  const {
@@ -69,8 +45,6 @@ function Main({ ccPairId }: { ccPairId: number }) {
  );

  const [hasLoadedOnce, setHasLoadedOnce] = useState(false);
  const [editingRefreshFrequency, setEditingRefreshFrequency] = useState(false);
  const [editingPruningFrequency, setEditingPruningFrequency] = useState(false);
  const { popup, setPopup } = usePopup();

  const finishConnectorDeletion = useCallback(() => {
@@ -116,86 +90,6 @@ function Main({ ccPairId }: { ccPairId: number }) {
    }
  };

  const handleRefreshEdit = async () => {
    setEditingRefreshFrequency(true);
  };

  const handlePruningEdit = async () => {
    setEditingPruningFrequency(true);
  };

  const handleRefreshSubmit = async (
    propertyName: string,
    propertyValue: string
  ) => {
    const parsedRefreshFreq = parseInt(propertyValue, 10);

    if (isNaN(parsedRefreshFreq)) {
      setPopup({
        message: "Invalid refresh frequency: must be an integer",
        type: "error",
      });
      return;
    }

    try {
      const response = await updateConnectorCredentialPairProperty(
        ccPairId,
        propertyName,
        String(parsedRefreshFreq)
      );
      if (!response.ok) {
        throw new Error(await response.text());
      }
      mutate(buildCCPairInfoUrl(ccPairId));
      setPopup({
        message: "Connector refresh frequency updated successfully",
        type: "success",
      });
    } catch (error) {
      setPopup({
        message: "Failed to update connector refresh frequency",
        type: "error",
      });
    }
  };

  const handlePruningSubmit = async (
    propertyName: string,
    propertyValue: string
  ) => {
    const parsedFreq = parseInt(propertyValue, 10);

    if (isNaN(parsedFreq)) {
      setPopup({
        message: "Invalid pruning frequency: must be an integer",
        type: "error",
      });
      return;
    }

    try {
      const response = await updateConnectorCredentialPairProperty(
        ccPairId,
        propertyName,
        String(parsedFreq)
      );
      if (!response.ok) {
        throw new Error(await response.text());
      }
      mutate(buildCCPairInfoUrl(ccPairId));
      setPopup({
        message: "Connector pruning frequency updated successfully",
        type: "success",
      });
    } catch (error) {
      setPopup({
        message: "Failed to update connector pruning frequency",
        type: "error",
      });
    }
  };

  if (isLoading) {
    return <ThreeDotsLoader />;
  }
@@ -220,35 +114,9 @@ function Main({ ccPairId }: { ccPairId: number }) {
    refresh_freq: refreshFreq,
    indexing_start: indexingStart,
  } = ccPair.connector;

  return (
    <>
      {popup}

      {editingRefreshFrequency && (
        <EditPropertyModal
          propertyTitle="Refresh Frequency"
          propertyDetails="How often the connector should refresh (in seconds)"
          propertyName="refresh_frequency"
          propertyValue={String(refreshFreq)}
          validationSchema={RefreshFrequencySchema}
          onSubmit={handleRefreshSubmit}
          onClose={() => setEditingRefreshFrequency(false)}
        />
      )}

      {editingPruningFrequency && (
        <EditPropertyModal
          propertyTitle="Pruning Frequency"
          propertyDetails="How often the connector should be pruned (in seconds)"
          propertyName="pruning_frequency"
          propertyValue={String(pruneFreq)}
          validationSchema={PruneFrequencySchema}
          onSubmit={handlePruningSubmit}
          onClose={() => setEditingPruningFrequency(false)}
        />
      )}

      <BackButton
        behaviorOverride={() => router.push("/admin/indexing/status")}
      />
@@ -257,7 +125,7 @@ function Main({ ccPairId }: { ccPairId: number }) {
          <SourceIcon iconSize={32} sourceType={ccPair.connector.source} />
        </div>

        <div className="ml-1 overflow-hidden text-ellipsis whitespace-nowrap flex-1 mr-4">
        <div className="ml-1">
          <EditableStringFieldDisplay
            value={ccPair.name}
            isEditable={ccPair.is_editable_for_current_user}
@@ -345,8 +213,6 @@ function Main({ ccPairId }: { ccPairId: number }) {
          pruneFreq={pruneFreq}
          indexingStart={indexingStart}
          refreshFreq={refreshFreq}
          onRefreshEdit={handleRefreshEdit}
          onPruningEdit={handlePruningEdit}
        />
      )}

@@ -49,8 +49,6 @@ import { useRouter } from "next/navigation";
import CardSection from "@/components/admin/CardSection";
import { prepareOAuthAuthorizationRequest } from "@/lib/oauth_utils";
import { EE_ENABLED, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import TemporaryLoadingModal from "@/components/TemporaryLoadingModal";
import { getConnectorOauthRedirectUrl } from "@/lib/connectors/oauth";
export interface AdvancedConfig {
  refreshFreq: number;
  pruneFreq: number;
@@ -162,7 +160,6 @@ export default function AddConnector({
  // Form context and popup management
  const { setFormStep, setAllowCreate, formStep } = useFormContext();
  const { popup, setPopup } = usePopup();
  const [uploading, setUploading] = useState(false);

  // Hooks for Google Drive and Gmail credentials
  const { liveGDriveCredential } = useGoogleDriveCredentials(connector);
@@ -340,24 +337,16 @@ export default function AddConnector({
    }
    // File-specific handling
    if (connector == "file") {
      setUploading(true);
      try {
        const response = await submitFiles(
          selectedFiles,
          setPopup,
          name,
          access_type,
          groups
        );
        if (response) {
          onSuccess();
        }
      } catch (error) {
        setPopup({ message: "Error uploading files", type: "error" });
      } finally {
        setUploading(false);
      const response = await submitFiles(
        selectedFiles,
        setPopup,
        name,
        access_type,
        groups
      );
      if (response) {
        onSuccess();
      }

      return;
    }

@@ -419,9 +408,9 @@ export default function AddConnector({
    <div className="mx-auto mb-8 w-full">
      {popup}

      {uploading && (
        <TemporaryLoadingModal content="Uploading files..." />
      )}
      <div className="mb-4">
        <HealthCheckBanner />
      </div>

      <AdminPageTitle
        includeDivider={false}
@@ -453,19 +442,11 @@ export default function AddConnector({
          {/* Button to pop up a form to manually enter credentials */}
          <button
            className="mt-6 text-sm bg-background-900 px-2 py-1.5 flex text-text-200 flex-none rounded mr-4"
            onClick={async () => {
              const redirectUrl =
                await getConnectorOauthRedirectUrl(connector);
              // if redirect is supported, just use it
              if (redirectUrl) {
                window.location.href = redirectUrl;
              } else {
                setCreateConnectorToggle(
                  (createConnectorToggle) =>
                    !createConnectorToggle
                );
              }
            }}
            onClick={() =>
              setCreateConnectorToggle(
                (createConnectorToggle) => !createConnectorToggle
              )
            }
          >
            Create New
          </button>

@@ -104,9 +104,7 @@ const GDriveMain = ({}: {}) => {
  const googleDriveServiceAccountCredential:
    | Credential<GoogleDriveServiceAccountCredentialJson>
    | undefined = credentialsData.find(
    (credential) =>
      credential.credential_json?.google_service_account_key &&
      credential.source === "google_drive"
    (credential) => credential.credential_json?.google_service_account_key
  );

  const googleDriveConnectorIndexingStatuses: ConnectorIndexingStatus<

@@ -135,7 +135,7 @@ export const DocumentFeedbackTable = ({
          />
        </TableCell>
        <TableCell>
          <div className="relative">
          <div className="ml-auto flex w-16">
            <div
              key={document.document_id}
              className="h-10 ml-auto mr-8"

web/src/app/admin/prompt-library/hooks.ts (new file, 46 lines)
@@ -0,0 +1,46 @@
import useSWR from "swr";
import { InputPrompt } from "./interfaces";

const fetcher = (url: string) => fetch(url).then((res) => res.json());

export const useAdminInputPrompts = () => {
  const { data, error, mutate } = useSWR<InputPrompt[]>(
    `/api/admin/input_prompt`,
    fetcher
  );

  return {
    data,
    error,
    isLoading: !error && !data,
    refreshInputPrompts: mutate,
  };
};

export const useInputPrompts = (includePublic: boolean = false) => {
  const { data, error, mutate } = useSWR<InputPrompt[]>(
    `/api/input_prompt${includePublic ? "?include_public=true" : ""}`,
    fetcher
  );

  return {
    data,
    error,
    isLoading: !error && !data,
    refreshInputPrompts: mutate,
  };
};

export const useInputPrompt = (id: number) => {
  const { data, error, mutate } = useSWR<InputPrompt>(
    `/api/input_prompt/${id}`,
    fetcher
  );

  return {
    data,
    error,
    isLoading: !error && !data,
    refreshInputPrompt: mutate,
  };
};
web/src/app/admin/prompt-library/interfaces.ts (new file, 31 lines)
@@ -0,0 +1,31 @@
export interface InputPrompt {
  id: number;
  prompt: string;
  content: string;
  active: boolean;
  is_public: string;
}

export interface EditPromptModalProps {
  onClose: () => void;
  promptId: number;
  editInputPrompt: (
    promptId: number,
    values: CreateInputPromptRequest
  ) => Promise<void>;
}

export interface CreateInputPromptRequest {
  prompt: string;
  content: string;
}

export interface AddPromptModalProps {
  onClose: () => void;
  onSubmit: (promptData: CreateInputPromptRequest) => void;
}

export interface PromptData {
  id: number;
  prompt: string;
  content: string;
}
web/src/app/admin/prompt-library/modals/AddPromptModal.tsx (new file, 69 lines)
@@ -0,0 +1,69 @@
import React from "react";
import { Formik, Form } from "formik";
import * as Yup from "yup";
import { Button } from "@/components/ui/button";

import { BookstackIcon } from "@/components/icons/icons";
import { AddPromptModalProps } from "../interfaces";
import { TextFormField } from "@/components/admin/connectors/Field";
import { Modal } from "@/components/Modal";

const AddPromptSchema = Yup.object().shape({
  title: Yup.string().required("Title is required"),
  prompt: Yup.string().required("Prompt is required"),
});

const AddPromptModal = ({ onClose, onSubmit }: AddPromptModalProps) => {
  return (
    <Modal onOutsideClick={onClose} width="w-full max-w-3xl">
      <Formik
        initialValues={{
          title: "",
          prompt: "",
        }}
        validationSchema={AddPromptSchema}
        onSubmit={(values, { setSubmitting }) => {
          onSubmit({
            prompt: values.title,
            content: values.prompt,
          });
          setSubmitting(false);
          onClose();
        }}
      >
        {({ isSubmitting, setFieldValue }) => (
          <Form>
            <h2 className="w-full text-2xl gap-x-2 text-emphasis font-bold mb-3 flex items-center">
              <BookstackIcon size={20} />
              Add prompt
            </h2>

            <TextFormField
              label="Title"
              name="title"
              placeholder="Title (e.g. 'Reword')"
            />

            <TextFormField
              isTextArea
              label="Prompt"
              name="prompt"
              placeholder="Enter a prompt (e.g. 'help me rewrite the following politely and concisely for professional communication')"
            />

            <Button
              type="submit"
              className="w-full"
              disabled={isSubmitting}
              variant="submit"
            >
              Add prompt
            </Button>
          </Form>
        )}
      </Formik>
    </Modal>
  );
};

export default AddPromptModal;
web/src/app/admin/prompt-library/modals/EditPromptModal.tsx (new file, 138 lines)
@@ -0,0 +1,138 @@
import React from "react";
import { Formik, Form, Field, ErrorMessage } from "formik";
import * as Yup from "yup";
import { Modal } from "@/components/Modal";
import { Textarea } from "@/components/ui/textarea";
import { Button } from "@/components/ui/button";
import { useInputPrompt } from "../hooks";
import { EditPromptModalProps } from "../interfaces";

const EditPromptSchema = Yup.object().shape({
  prompt: Yup.string().required("Title is required"),
  content: Yup.string().required("Content is required"),
  active: Yup.boolean(),
});

const EditPromptModal = ({
  onClose,
  promptId,
  editInputPrompt,
}: EditPromptModalProps) => {
  const {
    data: promptData,
    error,
    refreshInputPrompt,
  } = useInputPrompt(promptId);

  if (error)
    return (
      <Modal onOutsideClick={onClose} width="max-w-xl">
        <p>Failed to load prompt data</p>
      </Modal>
    );

  if (!promptData)
    return (
      <Modal onOutsideClick={onClose} width="w-full max-w-xl">
        <p>Loading...</p>
      </Modal>
    );

  return (
    <Modal onOutsideClick={onClose} width="w-full max-w-xl">
      <Formik
        initialValues={{
          prompt: promptData.prompt,
          content: promptData.content,
          active: promptData.active,
        }}
        validationSchema={EditPromptSchema}
        onSubmit={(values) => {
          editInputPrompt(promptId, values);
          refreshInputPrompt();
        }}
      >
        {({ isSubmitting, values }) => (
          <Form className="items-stretch">
            <h2 className="text-2xl text-emphasis font-bold mb-3 flex items-center">
              <svg
                className="w-6 h-6 mr-2"
                viewBox="0 0 24 24"
                fill="currentColor"
              >
                <path d="M3 17.25V21h3.75L17.81 9.94l-3.75-3.75L3 17.25zM20.71 7.04c.39-.39.39-1.02 0-1.41l-2.34-2.34c-.39-.39-1.02-.39-1.41 0l-1.83 1.83 3.75 3.75 1.83-1.83z" />
              </svg>
              Edit prompt
            </h2>

            <div className="space-y-4">
              <div>
                <label
                  htmlFor="prompt"
                  className="block text-sm font-medium mb-1"
                >
                  Title
                </label>
                <Field
                  as={Textarea}
                  id="prompt"
                  name="prompt"
                  placeholder="Title (e.g. 'Draft email')"
                />
                <ErrorMessage
                  name="prompt"
                  component="div"
                  className="text-red-500 text-sm mt-1"
                />
              </div>

              <div>
                <label
                  htmlFor="content"
                  className="block text-sm font-medium mb-1"
                >
                  Content
                </label>
                <Field
                  as={Textarea}
                  id="content"
                  name="content"
                  placeholder="Enter prompt content (e.g. 'Write a professional-sounding email about the following content')"
                  rows={4}
                />
                <ErrorMessage
                  name="content"
                  component="div"
                  className="text-red-500 text-sm mt-1"
                />
              </div>

              <div>
                <label className="flex items-center">
                  <Field type="checkbox" name="active" className="mr-2" />
                  Active prompt
                </label>
              </div>
            </div>

            <div className="mt-6">
              <Button
                type="submit"
                disabled={
                  isSubmitting ||
                  (values.prompt === promptData.prompt &&
                    values.content === promptData.content &&
                    values.active === promptData.active)
                }
              >
                {isSubmitting ? "Updating..." : "Update prompt"}
              </Button>
            </div>
          </Form>
        )}
      </Formik>
    </Modal>
  );
};

export default EditPromptModal;
web/src/app/admin/prompt-library/page.tsx (new file, 32 lines)
@@ -0,0 +1,32 @@
"use client";

import { AdminPageTitle } from "@/components/admin/Title";
import { ClosedBookIcon } from "@/components/icons/icons";
import { useAdminInputPrompts } from "./hooks";
import { PromptSection } from "./promptSection";

const Page = () => {
  const {
    data: promptLibrary,
    error: promptLibraryError,
    isLoading: promptLibraryIsLoading,
    refreshInputPrompts: refreshPrompts,
  } = useAdminInputPrompts();

  return (
    <div className="container mx-auto">
      <AdminPageTitle
        icon={<ClosedBookIcon size={32} />}
        title="Prompt Library"
      />
      <PromptSection
        promptLibrary={promptLibrary || []}
        isLoading={promptLibraryIsLoading}
        error={promptLibraryError}
        refreshPrompts={refreshPrompts}
        isPublic={true}
      />
    </div>
  );
};

export default Page;
web/src/app/admin/prompt-library/promptLibrary.tsx (new file, 249 lines)
@@ -0,0 +1,249 @@
"use client";

import { EditIcon, TrashIcon } from "@/components/icons/icons";
import { PopupSpec } from "@/components/admin/connectors/Popup";
import { MagnifyingGlass } from "@phosphor-icons/react";
import { useState } from "react";
import {
  Table,
  TableHead,
  TableRow,
  TableBody,
  TableCell,
} from "@/components/ui/table";
import { FilterDropdown } from "@/components/search/filtering/FilterDropdown";
import { FiTag } from "react-icons/fi";
import { PageSelector } from "@/components/PageSelector";
import { InputPrompt } from "./interfaces";
import { DeleteEntityModal } from "@/components/modals/DeleteEntityModal";
import { TableHeader } from "@/components/ui/table";

const CategoryBubble = ({
  name,
  onDelete,
}: {
  name: string;
  onDelete?: () => void;
}) => (
  <span
    className={`
      inline-block
      px-2
      py-1
      mr-1
      mb-1
      text-xs
      font-semibold
      text-emphasis
      bg-hover
      rounded-full
      items-center
      w-fit
      ${onDelete ? "cursor-pointer" : ""}
    `}
    onClick={onDelete}
  >
    {name}
    {onDelete && (
      <button
        className="ml-1 text-subtle hover:text-emphasis"
        aria-label="Remove category"
      >
        ×
      </button>
    )}
  </span>
);

const NUM_RESULTS_PER_PAGE = 10;

export const PromptLibraryTable = ({
  promptLibrary,
  refresh,
  setPopup,
  handleEdit,
  isPublic,
}: {
  promptLibrary: InputPrompt[];
  refresh: () => void;
  setPopup: (popup: PopupSpec | null) => void;
  handleEdit: (promptId: number) => void;
  isPublic: boolean;
}) => {
  const [query, setQuery] = useState("");
  const [currentPage, setCurrentPage] = useState(1);
  const [selectedStatus, setSelectedStatus] = useState<string[]>([]);

  const columns = [
    { name: "Prompt", key: "prompt" },
    { name: "Content", key: "content" },
    { name: "Status", key: "status" },
    { name: "", key: "edit" },
    { name: "", key: "delete" },
  ];

  const filteredPromptLibrary = promptLibrary.filter((item) => {
    const cleanedQuery = query.toLowerCase();
    const searchMatch =
      item.prompt.toLowerCase().includes(cleanedQuery) ||
      item.content.toLowerCase().includes(cleanedQuery);
    const statusMatch =
      selectedStatus.length === 0 ||
      (selectedStatus.includes("Active") && item.active) ||
      (selectedStatus.includes("Inactive") && !item.active);

    return searchMatch && statusMatch;
  });

  const totalPages = Math.ceil(
    filteredPromptLibrary.length / NUM_RESULTS_PER_PAGE
  );
  const startIndex = (currentPage - 1) * NUM_RESULTS_PER_PAGE;
  const endIndex = startIndex + NUM_RESULTS_PER_PAGE;
  const paginatedPromptLibrary = filteredPromptLibrary.slice(
    startIndex,
    endIndex
  );

  const handlePageChange = (page: number) => {
    setCurrentPage(page);
  };

  const handleDelete = async (id: number) => {
    const response = await fetch(
      `/api${isPublic ? "/admin" : ""}/input_prompt/${id}`,
      {
        method: "DELETE",
      }
    );
    if (!response.ok) {
      setPopup({ message: "Failed to delete input prompt", type: "error" });
    }
    refresh();
    setConfirmDeletionId(null);
  };

  const handleStatusSelect = (status: string) => {
    setSelectedStatus((prev) => {
      if (prev.includes(status)) {
        return prev.filter((s) => s !== status);
      }
      return [...prev, status];
    });
  };

  const [confirmDeletionId, setConfirmDeletionId] = useState<number | null>(
    null
  );

  return (
    <div className="justify-center py-2">
      {confirmDeletionId != null && (
        <DeleteEntityModal
          onClose={() => setConfirmDeletionId(null)}
          onSubmit={() => handleDelete(confirmDeletionId)}
          entityType="prompt"
          entityName={
            paginatedPromptLibrary.find(
              (prompt) => prompt.id === confirmDeletionId
            )?.prompt ?? ""
          }
        />
      )}

      <div className="flex items-center w-full border-2 border-border rounded-lg px-4 py-2 focus-within:border-accent">
        <MagnifyingGlass />
        <input
          className="flex-grow ml-2 bg-transparent outline-none placeholder-subtle"
          placeholder="Find prompts..."
          value={query}
          onChange={(event) => {
            setQuery(event.target.value);
            setCurrentPage(1);
          }}
        />
      </div>
      <div className="my-4 border-b border-border">
        <FilterDropdown
          options={[
            { key: "Active", display: "Active" },
            { key: "Inactive", display: "Inactive" },
          ]}
          selected={selectedStatus}
          handleSelect={(option) => handleStatusSelect(option.key)}
          icon={<FiTag size={16} />}
          defaultDisplay="All Statuses"
        />
        <div className="flex flex-col items-stretch w-full flex-wrap pb-4 mt-3">
          {selectedStatus.map((status) => (
            <CategoryBubble
              key={status}
              name={status}
              onDelete={() => handleStatusSelect(status)}
            />
          ))}
        </div>
      </div>
      <div className="mx-auto overflow-x-auto">
        <Table>
          <TableHeader>
            <TableRow>
              {columns.map((column) => (
                <TableHead key={column.key}>{column.name}</TableHead>
              ))}
            </TableRow>
          </TableHeader>
          <TableBody>
            {paginatedPromptLibrary.length > 0 ? (
              paginatedPromptLibrary
                .filter((prompt) => !(!isPublic && prompt.is_public))
                .map((item) => (
                  <TableRow key={item.id}>
                    <TableCell>{item.prompt}</TableCell>
                    <TableCell
                      className="
                        max-w-xs
                        overflow-hidden
                        text-ellipsis
                        break-words
                      "
                    >
                      {item.content}
                    </TableCell>
                    <TableCell>{item.active ? "Active" : "Inactive"}</TableCell>
                    <TableCell>
                      <button
                        className="cursor-pointer"
                        onClick={() => setConfirmDeletionId(item.id)}
                      >
                        <TrashIcon size={20} />
                      </button>
                    </TableCell>
                    <TableCell>
                      <button onClick={() => handleEdit(item.id)}>
                        <EditIcon size={12} />
                      </button>
                    </TableCell>
                  </TableRow>
                ))
            ) : (
              <TableRow>
                <TableCell colSpan={6}>No matching prompts found...</TableCell>
              </TableRow>
            )}
          </TableBody>
        </Table>
        {paginatedPromptLibrary.length > 0 && (
          <div className="mt-4 flex justify-center">
            <PageSelector
              currentPage={currentPage}
              totalPages={totalPages}
              onPageChange={handlePageChange}
              shouldScroll={true}
            />
          </div>
        )}
      </div>
    </div>
  );
};
web/src/app/admin/prompt-library/promptSection.tsx (new file, 150 lines)
@@ -0,0 +1,150 @@
"use client";

import { usePopup } from "@/components/admin/connectors/Popup";
import { ThreeDotsLoader } from "@/components/Loading";
import { ErrorCallout } from "@/components/ErrorCallout";
import { Button } from "@/components/ui/button";
import { Separator } from "@/components/ui/separator";
import Text from "@/components/ui/text";
import { useState } from "react";
import AddPromptModal from "./modals/AddPromptModal";
import EditPromptModal from "./modals/EditPromptModal";
import { PromptLibraryTable } from "./promptLibrary";
import { CreateInputPromptRequest, InputPrompt } from "./interfaces";

export const PromptSection = ({
  promptLibrary,
  isLoading,
  error,
  refreshPrompts,
  centering = false,
  isPublic,
}: {
  promptLibrary: InputPrompt[];
  isLoading: boolean;
  error: any;
  refreshPrompts: () => void;
  centering?: boolean;
  isPublic: boolean;
}) => {
  const { popup, setPopup } = usePopup();
  const [newPrompt, setNewPrompt] = useState(false);
  const [newPromptId, setNewPromptId] = useState<number | null>(null);

  const createInputPrompt = async (
    promptData: CreateInputPromptRequest
  ): Promise<InputPrompt> => {
    const response = await fetch("/api/input_prompt", {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
      },
      body: JSON.stringify({ ...promptData, is_public: isPublic }),
    });

    if (!response.ok) {
      setPopup({ message: "Failed to create input prompt", type: "error" });
    }

    refreshPrompts();
    return response.json();
  };

  const editInputPrompt = async (
    promptId: number,
    values: CreateInputPromptRequest
  ) => {
    try {
      const response = await fetch(`/api/input_prompt/${promptId}`, {
        method: "PATCH",
        headers: {
          "Content-Type": "application/json",
        },
        body: JSON.stringify(values),
      });

      if (!response.ok) {
        setPopup({ message: "Failed to update prompt!", type: "error" });
      }

      setNewPromptId(null);
      refreshPrompts();
    } catch (err) {
      setPopup({ message: `Failed to update prompt: ${err}`, type: "error" });
    }
  };

  if (isLoading) {
    return <ThreeDotsLoader />;
  }

  if (error || !promptLibrary) {
    return (
      <ErrorCallout
        errorTitle="Error loading prompts"
        errorMsg={error?.info?.message || error?.message?.info?.detail}
      />
    );
  }

  const handleEdit = (promptId: number) => {
    setNewPromptId(promptId);
  };

  return (
    <div
      className={`w-full ${
        centering ? "flex-col flex justify-center" : ""
      } mb-8`}
    >
      {popup}

      {newPrompt && (
        <AddPromptModal
          onSubmit={createInputPrompt}
          onClose={() => setNewPrompt(false)}
        />
      )}

      {newPromptId && (
        <EditPromptModal
          promptId={newPromptId}
          editInputPrompt={editInputPrompt}
          onClose={() => setNewPromptId(null)}
        />
      )}
      <div className={centering ? "max-w-sm mx-auto" : ""}>
        <Text className="mb-2 my-auto">
          Create prompts that can be accessed with the <i>`/`</i> shortcut in
          Danswer Chat.{" "}
          {isPublic
            ? "Prompts created here will be accessible to all users."
            : "Prompts created here will be available only to you."}
        </Text>
      </div>

      <div className="mb-2"></div>

      <Button
        onClick={() => setNewPrompt(true)}
        className={centering ? "mx-auto" : ""}
        variant="navigate"
        size="sm"
      >
        New Prompt
      </Button>

      <Separator />

      <div>
        <PromptLibraryTable
          isPublic={isPublic}
          promptLibrary={promptLibrary}
          setPopup={setPopup}
          refresh={refreshPrompts}
          handleEdit={handleEdit}
        />
      </div>
    </div>
  );
};
@@ -1,13 +1,13 @@
"use client";
import { useEffect, useState } from "react";
import { Tabs, TabsContent, TabsList, TabsTrigger } from "@/components/ui/tabs";
import { Card, CardContent, CardHeader, CardTitle } from "@/components/ui/card";
import { Button } from "@/components/ui/button";
import InvitedUserTable from "@/components/admin/users/InvitedUserTable";
import SignedUpUserTable from "@/components/admin/users/SignedUpUserTable";
import { SearchBar } from "@/components/search/SearchBar";
import { useState } from "react";
import { FiPlusSquare } from "react-icons/fi";
import { Modal } from "@/components/Modal";

import { Button } from "@/components/ui/button";
import Text from "@/components/ui/text";
import { LoadingAnimation } from "@/components/Loading";
import { AdminPageTitle } from "@/components/admin/Title";
import { usePopup, PopupSpec } from "@/components/admin/connectors/Popup";
@@ -15,10 +15,42 @@ import { UsersIcon } from "@/components/icons/icons";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useSWR, { mutate } from "swr";
import { ErrorCallout } from "@/components/ErrorCallout";
import { HidableSection } from "@/app/admin/assistants/HidableSection";
import BulkAdd from "@/components/admin/users/BulkAdd";
import { UsersResponse } from "@/lib/users/interfaces";
import SlackUserTable from "@/components/admin/users/SlackUserTable";
import Text from "@/components/ui/text";

const ValidDomainsDisplay = ({ validDomains }: { validDomains: string[] }) => {
if (!validDomains.length) {
return (
<div className="text-sm">
No invited users. Anyone can sign up with a valid email address. To
restrict access you can:
<div className="flex flex-wrap ml-2 mt-1">
(1) Invite users above. Once a user has been invited, only emails that
have explicitly been invited will be able to sign-up.
</div>
<div className="mt-1 ml-2">
(2) Set the{" "}
<b className="font-mono w-fit h-fit">VALID_EMAIL_DOMAINS</b>{" "}
environment variable to a comma separated list of email domains. This
will restrict access to users with email addresses from these domains.
</div>
</div>
);
}

return (
<div className="text-sm">
No invited users. Anyone with an email address with any of the following
domains can sign up: <i>{validDomains.join(", ")}</i>.
<div className="mt-2">
To further restrict access you can invite users above. Once a user has
been invited, only emails that have explicitly been invited will be able
to sign-up.
</div>
</div>
);
};

const UsersTables = ({
q,
@@ -29,48 +61,23 @@ const UsersTables = ({
}) => {
const [invitedPage, setInvitedPage] = useState(1);
const [acceptedPage, setAcceptedPage] = useState(1);
const [slackUsersPage, setSlackUsersPage] = useState(1);

const [usersData, setUsersData] = useState<UsersResponse | undefined>(
undefined
);
const [domainsData, setDomainsData] = useState<string[] | undefined>(
undefined
);

const { data, error, mutate } = useSWR<UsersResponse>(
`/api/manage/users?q=${encodeURIComponent(q)}&accepted_page=${
const { data, isLoading, mutate, error } = useSWR<UsersResponse>(
`/api/manage/users?q=${encodeURI(q)}&accepted_page=${
acceptedPage - 1
}&invited_page=${invitedPage - 1}&slack_users_page=${slackUsersPage - 1}`,
}&invited_page=${invitedPage - 1}`,
errorHandlingFetcher
);
const {
data: validDomains,
isLoading: isLoadingDomains,
error: domainsError,
} = useSWR<string[]>("/api/manage/admin/valid-domains", errorHandlingFetcher);

const { data: validDomains, error: domainsError } = useSWR<string[]>(
"/api/manage/admin/valid-domains",
errorHandlingFetcher
);

useEffect(() => {
if (data) {
setUsersData(data);
}
}, [data]);

useEffect(() => {
if (validDomains) {
setDomainsData(validDomains);
}
}, [validDomains]);

const activeData = data ?? usersData;
const activeDomains = validDomains ?? domainsData;

// Show loading animation only during the initial data fetch
if (!activeData || !activeDomains) {
if (isLoading || isLoadingDomains) {
return <LoadingAnimation text="Loading" />;
}

if (error) {
if (error || !data) {
return (
<ErrorCallout
errorTitle="Error loading users"
@@ -79,7 +86,7 @@ const UsersTables = ({
);
}

if (domainsError) {
if (domainsError || !validDomains) {
return (
<ErrorCallout
errorTitle="Error loading valid domains"
@@ -88,94 +95,45 @@ const UsersTables = ({
);
}

const {
accepted,
invited,
accepted_pages,
invited_pages,
slack_users,
slack_users_pages,
} = activeData;
const { accepted, invited, accepted_pages, invited_pages } = data;

// remove users that are already accepted
const finalInvited = invited.filter(
(user) => !accepted.some((u) => u.email === user.email)
(user) => !accepted.map((u) => u.email).includes(user.email)
);

return (
<Tabs defaultValue="invited">
<TabsList>
<TabsTrigger value="invited">Invited Users</TabsTrigger>
<TabsTrigger value="current">Current Users</TabsTrigger>
<TabsTrigger value="danswerbot">DanswerBot Users</TabsTrigger>
</TabsList>

<TabsContent value="invited">
<Card>
<CardHeader>
<CardTitle>Invited Users</CardTitle>
</CardHeader>
<CardContent>
{finalInvited.length > 0 ? (
<InvitedUserTable
users={finalInvited}
setPopup={setPopup}
currentPage={invitedPage}
onPageChange={setInvitedPage}
totalPages={invited_pages}
mutate={mutate}
/>
) : (
<p>Users that have been invited will show up here</p>
)}
</CardContent>
</Card>
</TabsContent>

<TabsContent value="current">
<Card>
<CardHeader>
<CardTitle>Current Users</CardTitle>
</CardHeader>
<CardContent>
{accepted.length > 0 ? (
<SignedUpUserTable
users={accepted}
setPopup={setPopup}
currentPage={acceptedPage}
onPageChange={setAcceptedPage}
totalPages={accepted_pages}
mutate={mutate}
/>
) : (
<p>Users that have an account will show up here</p>
)}
</CardContent>
</Card>
</TabsContent>

<TabsContent value="danswerbot">
<Card>
<CardHeader>
<CardTitle>DanswerBot Users</CardTitle>
</CardHeader>
<CardContent>
{slack_users.length > 0 ? (
<SlackUserTable
setPopup={setPopup}
currentPage={slackUsersPage}
onPageChange={setSlackUsersPage}
totalPages={slack_users_pages}
invitedUsers={finalInvited}
slackusers={slack_users}
mutate={mutate}
/>
) : (
<p>Slack-only users will show up here</p>
)}
</CardContent>
</Card>
</TabsContent>
</Tabs>
<>
<HidableSection sectionTitle="Invited Users">
{invited.length > 0 ? (
finalInvited.length > 0 ? (
<InvitedUserTable
users={finalInvited}
setPopup={setPopup}
currentPage={invitedPage}
onPageChange={setInvitedPage}
totalPages={invited_pages}
mutate={mutate}
/>
) : (
<div className="text-sm">
To invite additional teammates, use the <b>Invite Users</b> button
above!
</div>
)
) : (
<ValidDomainsDisplay validDomains={validDomains} />
)}
</HidableSection>
<SignedUpUserTable
users={accepted}
setPopup={setPopup}
currentPage={acceptedPage}
onPageChange={setAcceptedPage}
totalPages={accepted_pages}
mutate={mutate}
/>
</>
);
};

@@ -257,7 +215,6 @@ const Page = () => {
return (
<div className="mx-auto container">
<AdminPageTitle title="Manage Users" icon={<UsersIcon size={32} />} />

<SearchableTables />
</div>
);

web/src/app/assistants/mine/WrappedInputPrompts.tsx (Normal file, 51 lines)
@@ -0,0 +1,51 @@
"use client";
import SidebarWrapper from "../SidebarWrapper";
import { ChatSession } from "@/app/chat/interfaces";
import { Folder } from "@/app/chat/folders/interfaces";
import { User } from "@/lib/types";

import { AssistantsPageTitle } from "../AssistantsPageTitle";
import { useInputPrompts } from "@/app/admin/prompt-library/hooks";
import { PromptSection } from "@/app/admin/prompt-library/promptSection";

export default function WrappedPrompts({
chatSessions,
initiallyToggled,
folders,
openedFolders,
}: {
chatSessions: ChatSession[];
folders: Folder[];
initiallyToggled: boolean;
openedFolders?: { [key: number]: boolean };
}) {
const {
data: promptLibrary,
error: promptLibraryError,
isLoading: promptLibraryIsLoading,
refreshInputPrompts: refreshPrompts,
} = useInputPrompts(false);

return (
<SidebarWrapper
size="lg"
page="chat"
initiallyToggled={initiallyToggled}
chatSessions={chatSessions}
folders={folders}
openedFolders={openedFolders}
>
<div className="mx-auto w-searchbar-xs 2xl:w-searchbar-sm 3xl:w-searchbar">
<AssistantsPageTitle>Prompt Gallery</AssistantsPageTitle>
<PromptSection
promptLibrary={promptLibrary || []}
isLoading={promptLibraryIsLoading}
error={promptLibraryError}
refreshPrompts={refreshPrompts}
isPublic={false}
centering
/>
</div>
</SidebarWrapper>
);
}
@@ -60,11 +60,7 @@ export function ChatBanner() {
<div className={`flex justify-center w-full overflow-hidden pr-8`}>
<div
ref={contentRef}
className={`overflow-hidden ${
settings.enterpriseSettings.two_lines_for_chat_header
? "line-clamp-2"
: "line-clamp-1"
} text-center max-w-full`}
className={`overflow-hidden ${settings.enterpriseSettings.two_lines_for_chat_header ? "line-clamp-2" : "line-clamp-1"} text-center max-w-full`}
>
<MinimalMarkdown
className="prose text-sm max-w-full"
@@ -75,11 +71,7 @@ export function ChatBanner() {
<div className="absolute top-0 left-0 invisible flex justify-center max-w-full">
<div
ref={fullContentRef}
className={`overflow-hidden invisible ${
settings.enterpriseSettings.two_lines_for_chat_header
? "line-clamp-2"
: "line-clamp-1"
} text-center max-w-full`}
className={`overflow-hidden invisible ${settings.enterpriseSettings.two_lines_for_chat_header ? "line-clamp-2" : "line-clamp-1"} text-center max-w-full`}
>
<MinimalMarkdown
className="prose text-sm max-w-full"

@@ -135,6 +135,7 @@ export function ChatPage({
llmProviders,
folders,
openedFolders,
userInputPrompts,
defaultAssistantId,
shouldShowWelcomeModal,
refreshChatSessions,
@@ -470,14 +471,13 @@ export function ChatPage({
loadedSessionId != null) &&
!currentChatAnswering()
) {
updateCompleteMessageDetail(chatSession.chat_session_id, newMessageMap);

const latestMessageId =
newMessageHistory[newMessageHistory.length - 1]?.messageId;

setSelectedMessageForDocDisplay(
latestMessageId !== undefined ? latestMessageId : null
);

updateCompleteMessageDetail(chatSession.chat_session_id, newMessageMap);
}

setChatSessionSharedStatus(chatSession.shared_status);
@@ -976,7 +976,6 @@ export function ChatPage({

useEffect(() => {
if (
!personaIncludesRetrieval &&
(!selectedDocuments || selectedDocuments.length === 0) &&
documentSidebarToggled &&
!filtersToggled
@@ -1080,17 +1079,10 @@ export function ChatPage({
updateCanContinue(false, frozenSessionId);

if (currentChatState() != "input") {
if (currentChatState() == "uploading") {
setPopup({
message: "Please wait for the content to upload",
type: "error",
});
} else {
setPopup({
message: "Please wait for the response to complete",
type: "error",
});
}
setPopup({
message: "Please wait for the response to complete",
type: "error",
});

return;
}
@@ -1565,7 +1557,7 @@ export function ChatPage({
}
};

const handleImageUpload = async (acceptedFiles: File[]) => {
const handleImageUpload = (acceptedFiles: File[]) => {
const [_, llmModel] = getFinalLLM(
llmProviders,
liveAssistant,
@@ -1605,9 +1597,8 @@ export function ChatPage({
(file) => !tempFileDescriptors.some((newFile) => newFile.id === file.id)
);
};
updateChatState("uploading", currentSessionId());

await uploadFilesForChat(acceptedFiles).then(([files, error]) => {
uploadFilesForChat(acceptedFiles).then(([files, error]) => {
if (error) {
setCurrentMessageFiles((prev) => removeTempFiles(prev));
setPopup({
@@ -1618,7 +1609,6 @@ export function ChatPage({
setCurrentMessageFiles((prev) => [...removeTempFiles(prev), ...files]);
}
});
updateChatState("input", currentSessionId());
};
const [showHistorySidebar, setShowHistorySidebar] = useState(false); // State to track if sidebar is open

@@ -2241,7 +2231,7 @@ export function ChatPage({
ref={scrollableDivRef}
>
{liveAssistant && onAssistantChange && (
<div className="z-20 fixed top-0 pointer-events-none left-0 w-full flex justify-center overflow-visible">
<div className="z-20 fixed top-4 pointer-events-none left-0 w-full flex justify-center overflow-visible">
{!settings?.isMobile && (
<div
style={{ transition: "width 0.30s ease-out" }}
@@ -2754,6 +2744,7 @@ export function ChatPage({
chatState={currentSessionChatState}
stopGenerating={stopGenerating}
openModelSettings={() => setSettingsToggled(true)}
inputPrompts={userInputPrompts}
showDocs={() => setDocumentSelection(true)}
selectedDocuments={selectedDocuments}
// assistant stuff

@@ -85,7 +85,7 @@ export const ChatFilters = forwardRef<HTMLDivElement, ChatFiltersProps>(
return (
<div
id="danswer-chat-sidebar"
className={`relative max-w-full ${
className={`relative py-2 max-w-full ${
!modal ? "border-l h-full border-sidebar-border" : ""
}`}
onClick={(e) => {

@@ -2,7 +2,7 @@ import React, { useContext, useEffect, useRef, useState } from "react";
import { FiPlusCircle, FiPlus, FiInfo, FiX, FiSearch } from "react-icons/fi";
import { ChatInputOption } from "./ChatInputOption";
import { Persona } from "@/app/admin/assistants/interfaces";

import { InputPrompt } from "@/app/admin/prompt-library/interfaces";
import { FilterManager, LlmOverrideManager } from "@/lib/hooks";
import { SelectedFilterDisplay } from "./SelectedFilterDisplay";
import { useChatContext } from "@/components/context/ChatContext";
@@ -58,6 +58,7 @@ interface ChatInputBarProps {
llmOverrideManager: LlmOverrideManager;
chatState: ChatState;
alternativeAssistant: Persona | null;
inputPrompts: InputPrompt[];
// assistants
selectedAssistant: Persona;
setSelectedAssistant: (assistant: Persona) => void;
@@ -97,6 +98,7 @@ export function ChatInputBar({
textAreaRef,
alternativeAssistant,
chatSessionId,
inputPrompts,
toggleFilters,
}: ChatInputBarProps) {
useEffect(() => {
@@ -135,6 +137,7 @@ export function ChatInputBar({

const suggestionsRef = useRef<HTMLDivElement | null>(null);
const [showSuggestions, setShowSuggestions] = useState(false);
const [showPrompts, setShowPrompts] = useState(false);

const interactionsRef = useRef<HTMLDivElement | null>(null);

@@ -143,6 +146,19 @@ export function ChatInputBar({
setTabbingIconIndex(0);
};

const hidePrompts = () => {
setTimeout(() => {
setShowPrompts(false);
}, 50);

setTabbingIconIndex(0);
};

const updateInputPrompt = (prompt: InputPrompt) => {
hidePrompts();
setMessage(`${prompt.content}`);
};

useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
if (
@@ -152,6 +168,7 @@ export function ChatInputBar({
!interactionsRef.current.contains(event.target as Node))
) {
hideSuggestions();
hidePrompts();
}
};
document.addEventListener("mousedown", handleClickOutside);
@@ -181,10 +198,24 @@ export function ChatInputBar({
}
};

const handlePromptInput = (text: string) => {
if (!text.startsWith("/")) {
hidePrompts();
} else {
const promptMatch = text.match(/(?:\s|^)\/(\w*)$/);
if (promptMatch) {
setShowPrompts(true);
} else {
hidePrompts();
}
}
};

const handleInputChange = (event: React.ChangeEvent<HTMLTextAreaElement>) => {
const text = event.target.value;
setMessage(text);
handleAssistantInput(text);
handlePromptInput(text);
};

const assistantTagOptions = assistantOptions.filter((assistant) =>
@@ -196,26 +227,49 @@ export function ChatInputBar({
)
);

const filteredPrompts = inputPrompts.filter(
(prompt) =>
prompt.active &&
prompt.prompt.toLowerCase().startsWith(
message
.slice(message.lastIndexOf("/") + 1)
.split(/\s/)[0]
.toLowerCase()
)
);

const [tabbingIconIndex, setTabbingIconIndex] = useState(0);

const handleKeyDown = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (
showSuggestions &&
assistantTagOptions.length > 0 &&
((showSuggestions && assistantTagOptions.length > 0) || showPrompts) &&
(e.key === "Tab" || e.key == "Enter")
) {
e.preventDefault();

if (tabbingIconIndex == assistantTagOptions.length && showSuggestions) {
window.open("/assistants/new", "_self");
if (
(tabbingIconIndex == assistantTagOptions.length && showSuggestions) ||
(tabbingIconIndex == filteredPrompts.length && showPrompts)
) {
if (showPrompts) {
window.open("/prompts", "_self");
} else {
window.open("/assistants/new", "_self");
}
} else {
const option =
assistantTagOptions[tabbingIconIndex >= 0 ? tabbingIconIndex : 0];
if (showPrompts) {
const uppity =
filteredPrompts[tabbingIconIndex >= 0 ? tabbingIconIndex : 0];
updateInputPrompt(uppity);
} else {
const option =
assistantTagOptions[tabbingIconIndex >= 0 ? tabbingIconIndex : 0];

updatedTaggedAssistant(option);
updatedTaggedAssistant(option);
}
}
}
if (!showSuggestions) {
if (!showPrompts && !showSuggestions) {
return;
}

@@ -223,7 +277,10 @@ export function ChatInputBar({
e.preventDefault();

setTabbingIconIndex((tabbingIconIndex) =>
Math.min(tabbingIconIndex + 1, assistantTagOptions.length)
Math.min(
tabbingIconIndex + 1,
showPrompts ? filteredPrompts.length : assistantTagOptions.length
)
);
} else if (e.key === "ArrowUp") {
e.preventDefault();
@@ -284,6 +341,48 @@ export function ChatInputBar({
</div>
)}

{showPrompts && (
<div
ref={suggestionsRef}
className="text-sm absolute inset-x-0 top-0 w-full transform -translate-y-full"
>
<div className="rounded-lg py-1.5 bg-white border border-border-medium overflow-hidden shadow-lg mx-2 px-1.5 mt-2 rounded z-10">
{filteredPrompts.map(
(currentPrompt: InputPrompt, index: number) => (
<button
key={index}
className={`px-2 ${
tabbingIconIndex == index && "bg-hover"
} rounded content-start flex gap-x-1 py-1.5 w-full hover:bg-hover cursor-pointer`}
onClick={() => {
updateInputPrompt(currentPrompt);
}}
>
<p className="font-bold">{currentPrompt.prompt}:</p>
<p className="text-left flex-grow mr-auto line-clamp-1">
{currentPrompt.id == selectedAssistant.id &&
"(default) "}
{currentPrompt.content?.trim()}
</p>
</button>
)
)}

<a
key={filteredPrompts.length}
target="_self"
className={`${
tabbingIconIndex == filteredPrompts.length && "bg-hover"
} px-3 flex gap-x-1 py-2 w-full items-center hover:bg-hover-light cursor-pointer"`}
href="/prompts"
>
<FiPlus size={17} />
<p>Create a new prompt</p>
</a>
</div>
</div>
)}

{/* <div>
<SelectedFilterDisplay filterManager={filterManager} />
</div> */}
@@ -435,6 +534,7 @@ export function ChatInputBar({
onKeyDown={(event) => {
if (
event.key === "Enter" &&
!showPrompts &&
!showSuggestions &&
!event.shiftKey &&
!(event.nativeEvent as any).isComposing

@@ -1,12 +1,10 @@
import { Citation } from "@/components/search/results/Citation";
import { WebResultIcon } from "@/components/WebResultIcon";
import { LoadedDanswerDocument } from "@/lib/search/interfaces";
import { getSourceMetadata, SOURCE_METADATA_MAP } from "@/lib/sources";
import { getSourceMetadata } from "@/lib/sources";
import { ValidSources } from "@/lib/types";
import React, { memo } from "react";
import isEqual from "lodash/isEqual";
import { SlackIcon } from "@/components/icons/icons";
import { SourceIcon } from "@/components/SourceIcon";

export const MemoizedAnchor = memo(
({ docs, updatePresentingDocument, children }: any) => {
@@ -21,9 +19,19 @@ export const MemoizedAnchor = memo(
? new URL(associatedDoc.link).origin + "/favicon.ico"
: "";

const icon = (
<SourceIcon sourceType={associatedDoc?.source_type} iconSize={18} />
);
const getIcon = (sourceType: ValidSources, link: string) => {
return getSourceMetadata(sourceType).icon({ size: 18 });
};

const icon =
associatedDoc?.source_type === "web" ? (
<WebResultIcon url={associatedDoc.link} />
) : (
getIcon(
associatedDoc?.source_type || "web",
associatedDoc?.link || ""
)
);

return (
<MemoizedLink

@@ -35,13 +35,6 @@ export function SetDefaultModelModal({
const container = containerRef.current;
const message = messageRef.current;

const handleEscape = (e: KeyboardEvent) => {
if (e.key === "Escape") {
onClose();
}
};
window.addEventListener("keydown", handleEscape);

if (container && message) {
const checkScrollable = () => {
if (container.scrollHeight > container.clientHeight) {
@@ -52,14 +45,9 @@ export function SetDefaultModelModal({
};
checkScrollable();
window.addEventListener("resize", checkScrollable);
return () => {
window.removeEventListener("resize", checkScrollable);
window.removeEventListener("keydown", handleEscape);
};
return () => window.removeEventListener("resize", checkScrollable);
}

return () => window.removeEventListener("keydown", handleEscape);
}, [onClose]);
}, []);

const defaultModelDestructured = defaultModel
? destructureValue(defaultModel)

@@ -31,6 +31,7 @@ export default async function Page(props: {
openedFolders,
defaultAssistantId,
shouldShowWelcomeModal,
userInputPrompts,
ccPairs,
} = data;

@@ -52,6 +53,7 @@ export default async function Page(props: {
llmProviders,
folders,
openedFolders,
userInputPrompts,
shouldShowWelcomeModal,
defaultAssistantId,
}}

@@ -101,7 +101,7 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
flex-col relative
h-screen
transition-transform
`}
pt-2`}
>
<LogoType
showArrow={true}
@@ -163,6 +163,15 @@ export const HistorySidebar = forwardRef<HTMLDivElement, HistorySidebarProps>(
Manage Assistants
</p>
</Link>
<Link
href="/prompts"
className="w-full p-2 bg-white border-border border rounded items-center hover:bg-background-history-sidebar-button-hover cursor-pointer transition-all duration-150 flex gap-x-2"
>
<ClosedBookIcon className="h-4 w-4 my-auto text-text-history-sidebar-button" />
<p className="my-auto flex items-center text-sm ">
Manage Prompts
</p>
</Link>
</div>
)}
<div className="border-b border-divider-history-sidebar-bar pb-4 mx-3" />

@@ -22,7 +22,7 @@ export default function FixedLogo({
<>
<Link
href="/chat"
className="fixed cursor-pointer flex z-40 left-4 top-2 h-8"
className="fixed cursor-pointer flex z-40 left-2.5 top-2"
>
<div className="max-w-[200px] mobile:hidden flex items-center gap-x-1 my-auto">
<div className="flex-none my-auto">
@@ -46,8 +46,8 @@ export default function FixedLogo({
</div>
</div>
</Link>
<div className="mobile:hidden fixed left-4 bottom-4">
<FiSidebar className="text-text-mobile-sidebar" />
<div className="mobile:hidden fixed left-2.5 bottom-4">
{/* <FiSidebar className="text-text-mobile-sidebar" /> */}
</div>
</>
);

@@ -1,10 +1,5 @@
export type FeedbackType = "like" | "dislike";
export type ChatState =
| "input"
| "loading"
| "streaming"
| "toolBuilding"
| "uploading";
export type ChatState = "input" | "loading" | "streaming" | "toolBuilding";
export interface RegenerationState {
regenerating: boolean;
finalMessageIndex: number;

@@ -1,43 +0,0 @@
import { INTERNAL_URL } from "@/lib/constants";
import { NextRequest, NextResponse } from "next/server";

// TODO: deprecate this and just go directly to the backend via /api/...
// For some reason Egnyte doesn't work when using /api, so leaving this as is for now
// If we do try and remove this, make sure we test the Egnyte connector oauth flow
export async function GET(request: NextRequest) {
try {
const backendUrl = new URL(INTERNAL_URL);
// Copy path and query parameters from incoming request
backendUrl.pathname = request.nextUrl.pathname;
backendUrl.search = request.nextUrl.search;

const response = await fetch(backendUrl, {
method: "GET",
headers: request.headers,
body: request.body,
signal: request.signal,
// @ts-ignore
duplex: "half",
});

const responseData = await response.json();
if (responseData.redirect_url) {
return NextResponse.redirect(responseData.redirect_url);
}

return new NextResponse(JSON.stringify(responseData), {
status: response.status,
headers: response.headers,
});
} catch (error: unknown) {
console.error("Proxy error:", error);
return NextResponse.json(
{
message: "Proxy error",
error:
error instanceof Error ? error.message : "An unknown error occurred",
},
{ status: 500 }
);
}
}
web/src/app/prompts/page.tsx (Normal file, 28 lines)
@@ -0,0 +1,28 @@
import { fetchChatData } from "@/lib/chat/fetchChatData";
import { unstable_noStore as noStore } from "next/cache";
import { redirect } from "next/navigation";
import WrappedPrompts from "../assistants/mine/WrappedInputPrompts";

export default async function GalleryPage(props: {
searchParams: Promise<{ [key: string]: string }>;
}) {
const searchParams = await props.searchParams;
noStore();

const data = await fetchChatData(searchParams);

if ("redirect" in data) {
redirect(data.redirect);
}

const { chatSessions, folders, openedFolders, toggleSidebar } = data;

return (
<WrappedPrompts
initiallyToggled={toggleSidebar}
chatSessions={chatSessions}
folders={folders}
openedFolders={openedFolders}
/>
);
}
@@ -1,14 +0,0 @@
export default function TemporaryLoadingModal({
content,
}: {
content: string;
}) {
return (
<div className="fixed inset-0 flex items-center justify-center z-50 bg-black bg-opacity-30">
<div className="bg-white rounded-xl p-8 shadow-2xl flex items-center space-x-6">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-neutral-950"></div>
<p className="text-xl font-medium text-gray-800">{content}</p>
</div>
</div>
);
}
@@ -58,7 +58,7 @@ export function ClientLayout({
return (
<div className="h-screen overflow-y-hidden">
<div className="flex h-full">
<div className="flex-none text-text-settings-sidebar bg-background-sidebar w-[250px] overflow-x-hidden z-20 pt-2 pb-8 h-full border-r border-border miniscroll overflow-auto">
<div className="flex-none text-text-settings-sidebar bg-background-sidebar w-[250px] z-20 pt-4 pb-8 h-full border-r border-border miniscroll overflow-auto">
<AdminSidebar
collections={[
{
@@ -169,6 +169,18 @@ export function ClientLayout({
),
link: "/admin/tools",
},
{
name: (
<div className="flex">
<ClosedBookIcon
className="text-icon-settings-sidebar"
size={18}
/>
<div className="ml-1">Prompt Library</div>
</div>
),
link: "/admin/prompt-library",
},
]
: []),
...(enableEnterprise
@@ -405,7 +417,7 @@ export function ClientLayout({
/>
</div>
<div className="pb-8 relative h-full overflow-y-auto w-full">
<div className="fixed left-0 gap-x-4 px-2 top-2 h-8 px-0 mb-auto w-full items-start flex justify-end">
<div className="fixed bg-background left-0 gap-x-4 mb-8 px-4 py-2 w-full items-center flex justify-end">
<UserDropdown />
</div>
<div className="pt-20 flex overflow-y-auto overflow-x-hidden h-full px-4 md:px-12">

@@ -49,7 +49,7 @@ export function AccessTypeForm({
name: "Private",
value: "private",
description:
"Only users who have explicitly been given access to this connector (through the User Groups page) can access the documents pulled in by this connector",
"Only users who have expliticly been given access to this connector (through the User Groups page) can access the documents pulled in by this connector",
},
];

Some files were not shown because too many files have changed in this diff.